nvidia-nat 1.2.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +239 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +355 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +41 -0
- aiq/agent/react_agent/register.py +149 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +411 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +119 -0
- aiq/agent/tool_calling_agent/register.py +106 -0
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +96 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +285 -0
- aiq/builder/component_utils.py +316 -0
- aiq/builder/context.py +264 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +161 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +344 -0
- aiq/builder/function_base.py +380 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +174 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +74 -0
- aiq/builder/workflow.py +148 -0
- aiq/builder/workflow_builder.py +1117 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +213 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +297 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +246 -0
- aiq/cli/commands/uninstall.py +81 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +135 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +488 -0
- aiq/cli/type_registry.py +1000 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +694 -0
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +171 -0
- aiq/data_models/component.py +54 -0
- aiq/data_models/component_ref.py +168 -0
- aiq/data_models/config.py +406 -0
- aiq/data_models/dataset_handler.py +123 -0
- aiq/data_models/discovery_metadata.py +335 -0
- aiq/data_models/embedder.py +27 -0
- aiq/data_models/evaluate.py +127 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +72 -0
- aiq/data_models/interactive.py +246 -0
- aiq/data_models/intermediate_step.py +302 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +27 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +27 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +54 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/data_models/ttc_strategy.py +30 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +59 -0
- aiq/embedder/openai_embedder.py +43 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +60 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +254 -0
- aiq/eval/evaluate.py +506 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +45 -0
- aiq/eval/intermediate_step_adapter.py +99 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +178 -0
- aiq/eval/rag_evaluator/register.py +143 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +133 -0
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +75 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
- aiq/eval/tunable_rag_evaluator/register.py +52 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +140 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/test_time_compute/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/test_time_compute/functions/__init__.py +0 -0
- aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/test_time_compute/models/__init__.py +0 -0
- aiq/experimental/test_time_compute/models/editor_config.py +132 -0
- aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
- aiq/experimental/test_time_compute/models/search_config.py +120 -0
- aiq/experimental/test_time_compute/models/selection_config.py +154 -0
- aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
- aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
- aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
- aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
- aiq/experimental/test_time_compute/register.py +36 -0
- aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
- aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/test_time_compute/search/__init__.py +0 -0
- aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/test_time_compute/selection/__init__.py +0 -0
- aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +96 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +183 -0
- aiq/front_ends/fastapi/main.py +72 -0
- aiq/front_ends/fastapi/message_handler.py +298 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +321 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +46 -0
- aiq/llm/openai_llm.py +46 -0
- aiq/llm/register.py +23 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +76 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +322 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +265 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +310 -0
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +71 -0
- aiq/observability/register.py +96 -0
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +627 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +188 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +28 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +473 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +567 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +215 -0
- aiq/runtime/runner.py +190 -0
- aiq/runtime/session.py +158 -0
- aiq/runtime/user_metadata.py +130 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +151 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +267 -0
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- aiq/tool/code_execution/register.py +74 -0
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +255 -0
- aiq/tool/mcp/mcp_tool.py +96 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +38 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +66 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +290 -0
- aiq/utils/type_utils.py +484 -0
- aiq/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
- nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from abc import ABC
|
|
17
|
+
from abc import abstractmethod
|
|
18
|
+
|
|
19
|
+
from aiq.retriever.models import RetrieverOutput
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AIQRetriever(ABC):
    """
    Abstract base class describing how callers query a data store.

    A retriever is responsible for retrieving data from the data store it was configured with.

    Implementations may integrate with vector stores or other indexing backends that allow for text-based search.
    """

    @abstractmethod
    async def search(self, query: str, **kwargs) -> RetrieverOutput:
        """
        Retrieve at most ``top_k`` items from the data store for ``query``.

        How matches are found (for example, vector similarity search) is implementation dependent.

        Args:
            query (str): The text to search for.
            kwargs: Implementation-specific search options.

        Returns:
            RetrieverOutput: The results of the search.
        """
        raise NotImplementedError()
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
from pydantic import HttpUrl
|
|
18
|
+
|
|
19
|
+
from aiq.builder.builder import Builder
|
|
20
|
+
from aiq.builder.builder import LLMFrameworkEnum
|
|
21
|
+
from aiq.builder.retriever import RetrieverProviderInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_retriever_client
|
|
23
|
+
from aiq.cli.register_workflow import register_retriever_provider
|
|
24
|
+
from aiq.data_models.retriever import RetrieverBaseConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MilvusRetrieverConfig(RetrieverBaseConfig, name="milvus_retriever"):
    """
    Settings for a Retriever that reads its data from a Milvus service.
    """
    # --- Connection -------------------------------------------------------
    uri: HttpUrl = Field(description="The uri of Milvus service")
    connection_args: dict = Field(
        default={},
        description="Dictionary of arguments used to connect to and authenticate with the Milvus service",
    )

    # --- Query vectorization ----------------------------------------------
    embedding_model: str = Field(description="The name of the embedding model to use for vectorizing the query")

    # --- Search target and result shaping ---------------------------------
    collection_name: str | None = Field(default=None, description="The name of the milvus collection to search")
    # Also accepted under the alias "primary_field".
    content_field: str = Field(
        default="text",
        alias="primary_field",
        description="Name of the primary field to store/retrieve",
    )
    top_k: int | None = Field(default=None, gt=0, description="The number of results to return")
    output_fields: list[str] | None = Field(
        description="A list of fields to return from the datastore. If 'None', all fields but the vector are returned.",
        default=None,
    )
    search_params: dict = Field(
        description="Search parameters to use when performing vector search",
        default={"metric_type": "L2"},
    )
    vector_field: str = Field(description="Name of the field to compare with the vectorized query", default="vector")

    # --- Tool metadata ----------------------------------------------------
    # Also accepted under the alias "collection_description".
    description: str | None = Field(
        default=None,
        alias="collection_description",
        description="If present it will be used as the tool description",
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@register_retriever_provider(config_type=MilvusRetrieverConfig)
async def milvus_retriever(retriever_config: MilvusRetrieverConfig, builder: Builder):
    """
    Register the Milvus retriever provider.

    Args:
        retriever_config (MilvusRetrieverConfig): The Milvus retriever configuration to expose.
        builder (Builder): The workflow builder; not used by this provider.

    Yields:
        RetrieverProviderInfo: Provider info wrapping ``retriever_config`` together with a short description.
    """
    # The description is published as provider metadata, so the product name must be spelled correctly
    # (it previously read "Miluvs").
    yield RetrieverProviderInfo(config=retriever_config,
                                description="An adapter for a Milvus data store to use with a Retriever Client")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@register_retriever_client(config_type=MilvusRetrieverConfig, wrapper_type=None)
async def milvus_retriever_client(config: MilvusRetrieverConfig, builder: Builder):
    """
    Build a ``MilvusRetriever`` from the configuration and yield it.

    The embedder named by ``config.embedding_model`` is fetched from the builder with the LangChain wrapper type,
    a ``MilvusClient`` is created from ``config.uri`` and ``config.connection_args``, and every optional search
    setting that is not ``None`` is bound to the retriever as a default.

    Args:
        config (MilvusRetrieverConfig): The Milvus retriever configuration.
        builder (Builder): The workflow builder used to look up the embedder.

    Yields:
        MilvusRetriever: The configured retriever with its search defaults bound.
    """
    # Imported lazily so pymilvus is only required when this client is actually built.
    from pymilvus import MilvusClient

    from aiq.retriever.milvus.retriever import MilvusRetriever

    embedder = await builder.get_embedder(embedder_name=config.embedding_model, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    milvus_client = MilvusClient(uri=str(config.uri), **config.connection_args)
    retriever = MilvusRetriever(
        client=milvus_client,
        embedder=embedder,
        content_field=config.content_field,
    )

    # Using parameters in the config to set default values which can be overridden during the function call.
    # Maps config field name -> search keyword argument. `MilvusRetriever._search_with_iterator` names the vector
    # column `vector_field_name`; binding the config value as `vector_field` would land in its `**kwargs` and the
    # configured column would be ignored in favour of the "vector" default.
    # NOTE(review): confirm `MilvusRetriever._search` uses the same `vector_field_name` keyword.
    search_defaults = {
        "collection_name": "collection_name",
        "top_k": "top_k",
        "output_fields": "output_fields",
        "search_params": "search_params",
        "vector_field": "vector_field_name",
    }
    model_dict = config.model_dump()
    optional_args = {
        search_kwarg: model_dict[config_field]
        for config_field, search_kwarg in search_defaults.items() if model_dict[config_field] is not None
    }

    retriever.bind(**optional_args)

    yield retriever
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from functools import partial
|
|
18
|
+
|
|
19
|
+
from langchain_core.embeddings import Embeddings
|
|
20
|
+
from pymilvus import MilvusClient
|
|
21
|
+
from pymilvus.client.abstract import Hit
|
|
22
|
+
|
|
23
|
+
from aiq.retriever.interface import AIQRetriever
|
|
24
|
+
from aiq.retriever.models import AIQDocument
|
|
25
|
+
from aiq.retriever.models import RetrieverError
|
|
26
|
+
from aiq.retriever.models import RetrieverOutput
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CollectionNotFoundError(RetrieverError):
    """Raised when a requested Milvus collection does not exist."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MilvusRetriever(AIQRetriever):
|
|
36
|
+
"""
|
|
37
|
+
Client for retrieving document chunks from a Milvus vectorstore
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(
    self,
    client: MilvusClient,
    embedder: Embeddings,
    content_field: str = "text",
    use_iterator: bool = False,
) -> None:
    """
    Initialize the Milvus Retriever using a preconfigured MilvusClient

    Args:
        client (MilvusClient): Preinstantiated pymilvus.MilvusClient object.
        embedder (Embeddings): Embedding model used to vectorize queries via its ``embed_query`` method.
        content_field (str): Stored on the instance as ``content_field``. Defaults to "text".
        use_iterator (bool): When True, searches go through ``_search_with_iterator`` instead of ``_search``.

    Raises:
        ValueError: If ``use_iterator`` is True but the client exposes no ``search_iterator`` attribute.
    """
    self._client = client
    self._embedder = embedder

    # Fail fast: the iterator code path needs a client that provides `search_iterator`.
    if use_iterator and "search_iterator" not in dir(self._client):
        raise ValueError("This version of the pymilvus.MilvusClient does not support the search iterator.")

    self._search_func = self._search if not use_iterator else self._search_with_iterator
    self._default_params = None
    # Names of search parameters that have been given defaults through `bind`.
    self._bound_params = []
    self.content_field = content_field
    # The log message previously misspelled the product name as "Mivlus".
    logger.info("Milvus Retriever using %s for search.", self._search_func.__name__)
|
|
64
|
+
|
|
65
|
+
def bind(self, **kwargs) -> None:
    """
    Bind default values to the search method. Cannot bind the 'query' parameter.

    Bindings accumulate: calling ``bind`` again keeps earlier defaults (a repeated name takes the newer value),
    and the record of bound parameter names grows to match.

    Args:
        kwargs (dict): Key value pairs corresponding to the default values of search parameters.
    """
    # 'query' must always come from the caller, so it is dropped rather than bound.
    defaults = {name: value for name, value in kwargs.items() if name != "query"}
    self._search_func = partial(self._search_func, **defaults)

    # partial() layers on top of any earlier bindings, so the bookkeeping has to accumulate as well;
    # overwriting the list would make get_unbound_params() report parameters that are still bound.
    newly_bound = [name for name in defaults if name not in self._bound_params]
    self._bound_params = [*self._bound_params, *newly_bound]
    logger.debug("Binding parameters for search function: %s", defaults)
|
|
77
|
+
|
|
78
|
+
def get_unbound_params(self) -> list[str]:
|
|
79
|
+
"""
|
|
80
|
+
Returns a list of unbound parameters which will need to be passed to the search function.
|
|
81
|
+
"""
|
|
82
|
+
return [param for param in ["query", "collection_name", "top_k", "filters"] if param not in self._bound_params]
|
|
83
|
+
|
|
84
|
+
def _validate_collection(self, collection_name: str) -> bool:
|
|
85
|
+
return collection_name in self._client.list_collections()
|
|
86
|
+
|
|
87
|
+
async def search(self, query: str, **kwargs):
|
|
88
|
+
return await self._search_func(query=query, **kwargs)
|
|
89
|
+
|
|
90
|
+
async def _search_with_iterator(self,
|
|
91
|
+
query: str,
|
|
92
|
+
*,
|
|
93
|
+
collection_name: str,
|
|
94
|
+
top_k: int,
|
|
95
|
+
filters: str | None = None,
|
|
96
|
+
output_fields: list[str] | None = None,
|
|
97
|
+
search_params: dict | None = None,
|
|
98
|
+
timeout: float | None = None,
|
|
99
|
+
vector_field_name: str | None = "vector",
|
|
100
|
+
distance_cutoff: float | None = None,
|
|
101
|
+
**kwargs):
|
|
102
|
+
"""
|
|
103
|
+
Retrieve document chunks from a Milvus vectorstore using a search iterator, allowing for the retrieval of more
|
|
104
|
+
results.
|
|
105
|
+
"""
|
|
106
|
+
logger.debug("MilvusRetriever searching query: %s, for collection: %s. Returning max %s results",
|
|
107
|
+
query,
|
|
108
|
+
collection_name,
|
|
109
|
+
top_k)
|
|
110
|
+
|
|
111
|
+
if not self._validate_collection(collection_name):
|
|
112
|
+
raise CollectionNotFoundError(f"Collection: {collection_name} does not exist")
|
|
113
|
+
|
|
114
|
+
# If no output fields are specified, return all of them
|
|
115
|
+
if not output_fields:
|
|
116
|
+
collection_schema = self._client.describe_collection(collection_name)
|
|
117
|
+
output_fields = [
|
|
118
|
+
field["name"] for field in collection_schema.get("fields") if field["name"] != vector_field_name
|
|
119
|
+
]
|
|
120
|
+
|
|
121
|
+
search_vector = self._embedder.embed_query(query)
|
|
122
|
+
|
|
123
|
+
search_iterator = self._client.search_iterator(
|
|
124
|
+
collection_name=collection_name,
|
|
125
|
+
data=[search_vector],
|
|
126
|
+
batch_size=kwargs.get("batch_size", 1000),
|
|
127
|
+
filter=filters,
|
|
128
|
+
limit=top_k,
|
|
129
|
+
output_fields=output_fields,
|
|
130
|
+
search_params=search_params if search_params else {"metric_type": "L2"},
|
|
131
|
+
timeout=timeout,
|
|
132
|
+
anns_field=vector_field_name,
|
|
133
|
+
round_decimal=kwargs.get("round_decimal", -1),
|
|
134
|
+
partition_names=kwargs.get("partition_names", None),
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
results = []
|
|
138
|
+
try:
|
|
139
|
+
while True:
|
|
140
|
+
_res = search_iterator.next()
|
|
141
|
+
res = _res.get_res()
|
|
142
|
+
if len(_res) == 0:
|
|
143
|
+
search_iterator.close()
|
|
144
|
+
break
|
|
145
|
+
|
|
146
|
+
if distance_cutoff and res[0][-1].distance > distance_cutoff:
|
|
147
|
+
for i in range(len(res[0])):
|
|
148
|
+
if res[0][i].distance > distance_cutoff:
|
|
149
|
+
break
|
|
150
|
+
results.append(res[0][i])
|
|
151
|
+
break
|
|
152
|
+
results.extend(res[0])
|
|
153
|
+
|
|
154
|
+
return _wrap_milvus_results(results, content_field=self.content_field)
|
|
155
|
+
|
|
156
|
+
except Exception as e:
|
|
157
|
+
logger.exception("Exception when retrieving results from milvus for query %s: %s", query, e)
|
|
158
|
+
raise RetrieverError(f"Error when retrieving documents from {collection_name} for query '{query}'") from e
|
|
159
|
+
|
|
160
|
+
async def _search(self,
|
|
161
|
+
query: str,
|
|
162
|
+
*,
|
|
163
|
+
collection_name: str,
|
|
164
|
+
top_k: int,
|
|
165
|
+
filters: str | None = None,
|
|
166
|
+
output_fields: list[str] | None = None,
|
|
167
|
+
search_params: dict | None = None,
|
|
168
|
+
timeout: float | None = None,
|
|
169
|
+
vector_field_name: str | None = "vector",
|
|
170
|
+
**kwargs):
|
|
171
|
+
"""
|
|
172
|
+
Retrieve document chunks from a Milvus vectorstore
|
|
173
|
+
"""
|
|
174
|
+
logger.debug("MilvusRetriever searching query: %s, for collection: %s. Returning max %s results",
|
|
175
|
+
query,
|
|
176
|
+
collection_name,
|
|
177
|
+
top_k)
|
|
178
|
+
|
|
179
|
+
if not self._validate_collection(collection_name):
|
|
180
|
+
raise CollectionNotFoundError(f"Collection: {collection_name} does not exist")
|
|
181
|
+
|
|
182
|
+
available_fields = [v.get("name") for v in self._client.describe_collection(collection_name).get("fields", {})]
|
|
183
|
+
|
|
184
|
+
if self.content_field not in available_fields:
|
|
185
|
+
raise ValueError(f"The specified content field: {self.content_field} is not part of the schema.")
|
|
186
|
+
|
|
187
|
+
if vector_field_name not in available_fields:
|
|
188
|
+
raise ValueError(f"The specified vector field name: {vector_field_name} is not part of the schema.")
|
|
189
|
+
|
|
190
|
+
# If no output fields are specified, return all of them
|
|
191
|
+
if not output_fields:
|
|
192
|
+
output_fields = [field for field in available_fields if field != vector_field_name]
|
|
193
|
+
|
|
194
|
+
if self.content_field not in output_fields:
|
|
195
|
+
output_fields.append(self.content_field)
|
|
196
|
+
|
|
197
|
+
search_vector = self._embedder.embed_query(query)
|
|
198
|
+
res = self._client.search(
|
|
199
|
+
collection_name=collection_name,
|
|
200
|
+
data=[search_vector],
|
|
201
|
+
filter=filters,
|
|
202
|
+
output_fields=output_fields,
|
|
203
|
+
search_params=search_params if search_params else {"metric_type": "L2"},
|
|
204
|
+
timeout=timeout,
|
|
205
|
+
anns_field=vector_field_name,
|
|
206
|
+
limit=top_k,
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
return _wrap_milvus_results(res[0], content_field=self.content_field)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _wrap_milvus_results(res: list[Hit], content_field: str):
    """
    Wrap each Milvus result in ``res`` as a document and collect them into a RetrieverOutput.
    """
    documents = []
    for item in res:
        documents.append(_wrap_milvus_single_results(item, content_field=content_field))
    return RetrieverOutput(results=documents)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _wrap_milvus_single_results(res: Hit | dict, content_field: str) -> AIQDocument:
    """
    Convert a single Milvus search result into an AIQDocument.

    The value of ``content_field`` becomes the page content; every other returned field, plus the hit's
    ``distance``, goes into the metadata. The result id is converted to a string, because the document id field
    is declared as ``str | None`` while the id is passed through from Milvus as-is.
    NOTE(review): assumes Milvus can return integer ids here — confirm against the collection's primary key type.

    Args:
        res (Hit | dict): One search result, either a ``Hit`` object or a dict with 'entity' and 'id' keys.
        content_field (str): Name of the field holding the document text.

    Returns:
        AIQDocument: The wrapped result.

    Raises:
        ValueError: If ``res`` is neither a ``Hit`` nor a ``dict``.
    """
    if not isinstance(res, (Hit, dict)):
        raise ValueError(f"Milvus search returned object of type {type(res)}. Expected 'Hit' or 'dict'.")

    if isinstance(res, Hit):
        metadata = {k: v for k, v in res.fields.items() if k != content_field}
        metadata.update({"distance": res.distance})
        raw_id = res.id
        return AIQDocument(page_content=res.fields[content_field],
                           metadata=metadata,
                           document_id=None if raw_id is None else str(raw_id))

    fields = res["entity"]
    metadata = {k: v for k, v in fields.items() if k != content_field}
    metadata.update({"distance": res.get("distance")})
    raw_id = res["id"]
    return AIQDocument(page_content=fields.get(content_field),
                       metadata=metadata,
                       document_id=None if raw_id is None else str(raw_id))
|
aiq/retriever/models.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
from pydantic import Field
|
|
23
|
+
|
|
24
|
+
from aiq.utils.type_converter import GlobalTypeConverter
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AIQDocument(BaseModel):
    """
    Object representing a retrieved document/chunk from a standard AIQ Toolkit Retriever.
    """
    page_content: str = Field(description="Primary content of the document to insert or retrieve")
    metadata: dict[str, Any] = Field(description="Metadata dictionary attached to the AIQDocument")
    document_id: str | None = Field(description="Unique ID for the document, if supported by the configured datastore",
                                    default=None)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AIQDocument:
        """
        Deserialize an AIQDocument from a dictionary representation.

        Args:
            data (dict): A dictionary containing keys
                'page_content', 'metadata', and optionally 'document_id'.

        Returns:
            AIQDocument: A reconstructed AIQDocument instance.
        """
        # The dictionary entries are passed to the model constructor as keyword arguments.
        return cls(**data)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Container for the AIQDocuments produced by a retriever search.
class RetrieverOutput(BaseModel):
    results: list[AIQDocument] = Field(description="A list of retrieved AIQDocuments")

    def __len__(self):
        """Return the number of retrieved documents."""
        documents = self.results
        return len(documents)

    def __str__(self):
        """Return the dumped model serialized as a JSON string."""
        payload = self.model_dump()
        return json.dumps(payload)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class RetrieverError(Exception):
    """Error raised when a retriever fails to return documents."""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def retriever_output_to_dict(obj: RetrieverOutput) -> dict:
    """Return the dictionary produced by the output's ``model_dump``."""
    dumped = obj.model_dump()
    return dumped
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def retriever_output_to_str(obj: RetrieverOutput) -> str:
    """Return the string form of the output, as produced by ``str``."""
    text = str(obj)
    return text
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Register both RetrieverOutput converters with the GlobalTypeConverter when this module is imported.
GlobalTypeConverter.register_converter(retriever_output_to_dict)
GlobalTypeConverter.register_converter(retriever_output_to_str)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
from pydantic import HttpUrl
|
|
18
|
+
|
|
19
|
+
from aiq.builder.builder import Builder
|
|
20
|
+
from aiq.builder.retriever import RetrieverProviderInfo
|
|
21
|
+
from aiq.cli.register_workflow import register_retriever_client
|
|
22
|
+
from aiq.cli.register_workflow import register_retriever_provider
|
|
23
|
+
from aiq.data_models.retriever import RetrieverBaseConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class NemoRetrieverConfig(RetrieverBaseConfig, name="nemo_retriever"):
    """
    Configuration for a Retriever which pulls data from a Nemo Retriever service.
    """
    uri: HttpUrl = Field(description="The uri of the Nemo Retriever service.")
    # collection_name, top_k and output_fields are optional search defaults; a value of None leaves the
    # corresponding parameter unset here.
    collection_name: str | None = Field(description="The name of the collection to search", default=None)
    # When provided, top_k must satisfy 0 < top_k <= 50.
    top_k: int | None = Field(description="The number of results to return", gt=0, le=50, default=None)
    output_fields: list[str] | None = Field(
        default=None,
        description="A list of fields to return from the datastore. If 'None', all fields but the vector are returned.")
    timeout: int = Field(default=60, description="Maximum time to wait for results to be returned from the service.")
    nvidia_api_key: str | None = Field(
        description="API key used to authenticate with the service. If 'None', will use ENV Variable 'NVIDIA_API_KEY'",
        default=None,
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@register_retriever_provider(config_type=NemoRetrieverConfig)
async def nemo_retriever(retriever_config: NemoRetrieverConfig, builder: Builder):
    """
    Yield the provider info describing the Nemo retriever configuration.

    Args:
        retriever_config (NemoRetrieverConfig): Configuration attached to the provider info.
        builder (Builder): Not used by this provider.
    """
    provider_info = RetrieverProviderInfo(
        config=retriever_config,
        description="An adapter for a Nemo data store for use with a Retriever Client",
    )
    yield provider_info
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@register_retriever_client(config_type=NemoRetrieverConfig, wrapper_type=None)
async def nemo_retriever_client(config: NemoRetrieverConfig, builder: Builder):
    """
    Build a NemoRetriever from the configuration, bind its configured search defaults, and yield it.

    Args:
        config (NemoRetrieverConfig): Source of the constructor arguments and of the optional search defaults.
        builder (Builder): Not used by this client factory.
    """
    from aiq.retriever.nemo_retriever.retriever import NemoRetriever

    # Everything except the type tag and the two bind-only settings is passed to the constructor.
    constructor_args = config.model_dump(exclude={"type", "top_k", "collection_name"})
    retriever = NemoRetriever(**constructor_args)

    # Only settings that were actually provided (not None) are bound as search defaults.
    dumped = config.model_dump()
    bound_defaults = {}
    for name in ("collection_name", "top_k", "output_fields"):
        value = dumped[name]
        if value is not None:
            bound_defaults[name] = value

    retriever.bind(**bound_defaults)

    yield retriever
|