nvidia-nat 1.2.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +239 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +355 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +41 -0
- aiq/agent/react_agent/register.py +149 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +411 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +119 -0
- aiq/agent/tool_calling_agent/register.py +106 -0
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +96 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +285 -0
- aiq/builder/component_utils.py +316 -0
- aiq/builder/context.py +264 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +161 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +344 -0
- aiq/builder/function_base.py +380 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +174 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +74 -0
- aiq/builder/workflow.py +148 -0
- aiq/builder/workflow_builder.py +1117 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +213 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +297 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +246 -0
- aiq/cli/commands/uninstall.py +81 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +135 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +488 -0
- aiq/cli/type_registry.py +1000 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +694 -0
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +171 -0
- aiq/data_models/component.py +54 -0
- aiq/data_models/component_ref.py +168 -0
- aiq/data_models/config.py +406 -0
- aiq/data_models/dataset_handler.py +123 -0
- aiq/data_models/discovery_metadata.py +335 -0
- aiq/data_models/embedder.py +27 -0
- aiq/data_models/evaluate.py +127 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +72 -0
- aiq/data_models/interactive.py +246 -0
- aiq/data_models/intermediate_step.py +302 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +27 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +27 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +54 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/data_models/ttc_strategy.py +30 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +59 -0
- aiq/embedder/openai_embedder.py +43 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +60 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +254 -0
- aiq/eval/evaluate.py +506 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +45 -0
- aiq/eval/intermediate_step_adapter.py +99 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +178 -0
- aiq/eval/rag_evaluator/register.py +143 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +133 -0
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +75 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
- aiq/eval/tunable_rag_evaluator/register.py +52 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +140 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/test_time_compute/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/test_time_compute/functions/__init__.py +0 -0
- aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/test_time_compute/models/__init__.py +0 -0
- aiq/experimental/test_time_compute/models/editor_config.py +132 -0
- aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
- aiq/experimental/test_time_compute/models/search_config.py +120 -0
- aiq/experimental/test_time_compute/models/selection_config.py +154 -0
- aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
- aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
- aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
- aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
- aiq/experimental/test_time_compute/register.py +36 -0
- aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
- aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/test_time_compute/search/__init__.py +0 -0
- aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/test_time_compute/selection/__init__.py +0 -0
- aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +96 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +183 -0
- aiq/front_ends/fastapi/main.py +72 -0
- aiq/front_ends/fastapi/message_handler.py +298 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +321 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +46 -0
- aiq/llm/openai_llm.py +46 -0
- aiq/llm/register.py +23 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +76 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +322 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +265 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +310 -0
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +71 -0
- aiq/observability/register.py +96 -0
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +627 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +188 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +28 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +473 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +567 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +215 -0
- aiq/runtime/runner.py +190 -0
- aiq/runtime/session.py +158 -0
- aiq/runtime/user_metadata.py +130 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +151 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +267 -0
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- aiq/tool/code_execution/register.py +74 -0
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +255 -0
- aiq/tool/mcp/mcp_tool.py +96 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +38 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +66 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +290 -0
- aiq/utils/type_utils.py +484 -0
- aiq/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
- nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import datetime
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import uuid
|
|
20
|
+
from typing import Any
|
|
21
|
+
from typing import Literal
|
|
22
|
+
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import ValidationError
|
|
25
|
+
|
|
26
|
+
from aiq.data_models.api_server import AIQChatResponse
|
|
27
|
+
from aiq.data_models.api_server import AIQChatResponseChunk
|
|
28
|
+
from aiq.data_models.api_server import AIQResponseIntermediateStep
|
|
29
|
+
from aiq.data_models.api_server import AIQResponsePayloadOutput
|
|
30
|
+
from aiq.data_models.api_server import Error
|
|
31
|
+
from aiq.data_models.api_server import ErrorTypes
|
|
32
|
+
from aiq.data_models.api_server import SystemIntermediateStepContent
|
|
33
|
+
from aiq.data_models.api_server import SystemResponseContent
|
|
34
|
+
from aiq.data_models.api_server import TextContent
|
|
35
|
+
from aiq.data_models.api_server import WebSocketMessageStatus
|
|
36
|
+
from aiq.data_models.api_server import WebSocketMessageType
|
|
37
|
+
from aiq.data_models.api_server import WebSocketSystemInteractionMessage
|
|
38
|
+
from aiq.data_models.api_server import WebSocketSystemIntermediateStepMessage
|
|
39
|
+
from aiq.data_models.api_server import WebSocketSystemResponseTokenMessage
|
|
40
|
+
from aiq.data_models.api_server import WebSocketUserInteractionResponseMessage
|
|
41
|
+
from aiq.data_models.api_server import WebSocketUserMessage
|
|
42
|
+
from aiq.data_models.api_server import WorkflowSchemaType
|
|
43
|
+
from aiq.data_models.interactive import BinaryHumanPromptOption
|
|
44
|
+
from aiq.data_models.interactive import HumanPrompt
|
|
45
|
+
from aiq.data_models.interactive import HumanPromptBase
|
|
46
|
+
from aiq.data_models.interactive import HumanPromptBinary
|
|
47
|
+
from aiq.data_models.interactive import HumanPromptCheckbox
|
|
48
|
+
from aiq.data_models.interactive import HumanPromptDropdown
|
|
49
|
+
from aiq.data_models.interactive import HumanPromptRadio
|
|
50
|
+
from aiq.data_models.interactive import HumanPromptText
|
|
51
|
+
from aiq.data_models.interactive import HumanResponse
|
|
52
|
+
from aiq.data_models.interactive import HumanResponseBinary
|
|
53
|
+
from aiq.data_models.interactive import HumanResponseCheckbox
|
|
54
|
+
from aiq.data_models.interactive import HumanResponseDropdown
|
|
55
|
+
from aiq.data_models.interactive import HumanResponseRadio
|
|
56
|
+
from aiq.data_models.interactive import HumanResponseText
|
|
57
|
+
from aiq.data_models.interactive import MultipleChoiceOption
|
|
58
|
+
|
|
59
|
+
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class MessageValidator:
|
|
63
|
+
|
|
64
|
+
def __init__(self):
    """
    Build the lookup tables used to validate inbound WebSocket messages and to classify workflow output data.
    """
    # WebSocket message ``type`` value -> Pydantic model used to validate a message of that type.
    message_schemas: dict[str, type[BaseModel]] = {}
    message_schemas[WebSocketMessageType.USER_MESSAGE] = WebSocketUserMessage
    message_schemas[WebSocketMessageType.RESPONSE_MESSAGE] = WebSocketSystemResponseTokenMessage
    message_schemas[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = WebSocketSystemIntermediateStepMessage
    message_schemas[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = WebSocketSystemInteractionMessage
    message_schemas[WebSocketMessageType.USER_INTERACTION_MESSAGE] = WebSocketUserInteractionResponseMessage
    message_schemas[WebSocketMessageType.ERROR_MESSAGE] = Error
    self._message_type_schema_mapping: dict[str, type[BaseModel]] = message_schemas

    # Workflow schema type -> Pydantic model of the output data associated with it.
    data_schemas: dict[str, type[BaseModel]] = {}
    data_schemas[WorkflowSchemaType.GENERATE] = AIQResponsePayloadOutput
    data_schemas[WorkflowSchemaType.CHAT] = AIQChatResponse
    data_schemas[WorkflowSchemaType.CHAT_STREAM] = AIQChatResponseChunk
    data_schemas[WorkflowSchemaType.GENERATE_STREAM] = AIQResponseIntermediateStep
    self._data_type_schema_mapping: dict[str, type[BaseModel]] = data_schemas

    # Initial value for the message parent id.
    self._message_parent_id: str = "default_id"
|
|
80
|
+
|
|
81
|
+
async def validate_message(self, message: dict[str, Any]) -> BaseModel:
    """
    Validate an incoming WebSocket message against the schema registered for its ``type`` field.

    The message is rejected in any of these cases:

    - its ``type`` is missing or falsy;
    - its ``type`` is not registered, or resolves to the ``Error`` schema;
    - schema construction raises ``ValidationError``, ``TypeError`` or ``ValueError``.

    A rejected message produces an error response built by ``create_system_response_token_message``.

    :param message: Incoming WebSocket message as a dictionary.
    :return: The validated Pydantic model. If the message is rejected, the error response model is returned
             instead; it carries an ``ErrorTypes.INVALID_MESSAGE`` error.
    """
    try:
        incoming_type = message.get("type")
        if not incoming_type:
            raise ValueError(f"Missing message type: {json.dumps(message)}")

        target_schema: type[BaseModel] = await self.get_message_schema_by_type(incoming_type)

        # An ``Error`` schema means the type was unknown or was itself an error message;
        # neither is accepted as an inbound message.
        if issubclass(target_schema, Error):
            raise TypeError(
                f"An error was encountered processing an incoming WebSocket message of type: {incoming_type}")

        return target_schema(**message)

    except (ValidationError, TypeError, ValueError) as exc:
        logger.error("A data validation error %s occurred for message: %s", str(exc), str(message), exc_info=True)
        failure = Error(code=ErrorTypes.INVALID_MESSAGE, message="Error validating message.", details=str(exc))
        return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
                                                               content=failure)
|
|
111
|
+
|
|
112
|
+
async def get_message_schema_by_type(self, message_type: str) -> type[BaseModel]:
    """
    Retrieve the Pydantic model schema registered for a WebSocket message type.

    :param message_type: The ``type`` value of the message.
    :return: The registered schema class. If the type is not registered, or cannot be used as a lookup
             key (e.g. an unhashable list/dict value), the ``Error`` class is returned instead. The failure
             is logged.
    """
    try:
        schema: type[BaseModel] | None = self._message_type_schema_mapping.get(message_type)
    except TypeError as e:
        # Unhashable values (e.g. a JSON list or object in the ``type`` field) cannot be dict keys.
        logger.error("Error retrieving schema for message type '%s': %s", message_type, str(e), exc_info=True)
        return Error

    if schema is None:
        # An unknown type is an expected input condition, so no traceback is logged.
        logger.error("Error retrieving schema for message type '%s': %s",
                     message_type,
                     f"Unknown message type: {message_type}")
        return Error

    return schema
|
|
130
|
+
|
|
131
|
+
async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
    """
    Convert a workflow data model into the content model carried by an outgoing WebSocket message.

    Conversions:

    - ``AIQResponsePayloadOutput``: ``SystemResponseContent`` built from ``payload``.
    - ``AIQChatResponse`` / ``AIQChatResponseChunk``: ``SystemResponseContent`` built from the first
      choice's ``message.content``.
    - ``AIQResponseIntermediateStep``: ``SystemIntermediateStepContent`` built from ``name`` and ``payload``.
    - ``HumanPromptBase`` or ``SystemResponseContent``: returned unchanged.

    :param data_model: Pydantic data model instance.
    :return: The message content model. An ``Error`` with ``ErrorTypes.INVALID_DATA_CONTENT`` is returned
             instead in three cases: the input type is unsupported, a chat response has no choices, or
             building the content raises a ``ValueError`` (Pydantic validation errors are ``ValueError``
             subclasses).
    """
    validated_message_content: BaseModel | None = None
    try:
        if isinstance(data_model, AIQResponsePayloadOutput):
            validated_message_content = SystemResponseContent(text=data_model.payload)

        elif isinstance(data_model, (AIQChatResponse, AIQChatResponseChunk)):
            if not data_model.choices:
                # Without this guard, ``choices[0]`` would raise an IndexError that escapes the handler below.
                raise ValueError(f"Chat response contains no choices: {data_model.model_dump_json()}")
            validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)

        elif isinstance(data_model, AIQResponseIntermediateStep):
            validated_message_content = SystemIntermediateStepContent(name=data_model.name,
                                                                      payload=data_model.payload)
        elif isinstance(data_model, HumanPromptBase):
            validated_message_content = data_model
        elif isinstance(data_model, SystemResponseContent):
            return data_model
        else:
            raise ValueError(
                f"Input data could not be converted to validated message content: {data_model.model_dump_json()}")

        return validated_message_content

    except ValueError as e:
        logger.error("Input data could not be converted to validated message content: %s", str(e), exc_info=True)
        return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
|
|
163
|
+
|
|
164
|
+
async def convert_text_content_to_human_response(self, text_content: TextContent,
                                                 human_prompt: HumanPromptBase) -> HumanResponse:
    """
    Build the human response that answers ``human_prompt`` from the user's reply text.

    The prompt kind is tested in this order: text, binary, radio, checkbox, dropdown. For choice-style
    prompts, the reply text becomes the ``value`` of the selected option.

    :param text_content: Message text content holding the user's reply.
    :param human_prompt: The prompt being answered.
    :return: A ``HumanResponse`` matching the prompt kind. If the prompt kind is not recognised, or building
             the response raises ``ValueError``, a ``HumanResponseText`` is returned instead; its text is the
             error message.
    """
    try:
        if isinstance(human_prompt, HumanPromptText):
            return HumanResponseText(text=text_content.text)

        # Choice-style prompts, checked in order; each builder wraps the reply text in the option model.
        choice_builders = (
            (HumanPromptBinary,
             lambda value: HumanResponseBinary(selected_option=BinaryHumanPromptOption(value=value))),
            (HumanPromptRadio,
             lambda value: HumanResponseRadio(selected_option=MultipleChoiceOption(value=value))),
            (HumanPromptCheckbox,
             lambda value: HumanResponseCheckbox(selected_option=MultipleChoiceOption(value=value))),
            (HumanPromptDropdown,
             lambda value: HumanResponseDropdown(selected_option=MultipleChoiceOption(value=value))),
        )
        for prompt_cls, build_response in choice_builders:
            if isinstance(human_prompt, prompt_cls):
                return build_response(text_content.text)

        raise ValueError("Message content type not found")

    except ValueError as err:
        logger.error("Error human response content not found: %s", str(err), exc_info=True)
        return HumanResponseText(text=str(err))
|
|
198
|
+
|
|
199
|
+
async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
    """
    Map a validated data model to the WebSocket message type used to send it.

    Response payloads, chat responses and chat response chunks map to a response message; intermediate steps map to
    an intermediate-step message; human prompts map to a system-interaction message.

    :param data_model: Pydantic Data Model instance.
    :return: The matching ``WebSocketMessageType`` member, or ``WebSocketMessageType.ERROR_MESSAGE`` (after logging)
             when the model type is not recognized.
    """
    try:
        if isinstance(data_model, (AIQResponsePayloadOutput, AIQChatResponse, AIQChatResponseChunk)):
            return WebSocketMessageType.RESPONSE_MESSAGE

        if isinstance(data_model, AIQResponseIntermediateStep):
            return WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE

        if isinstance(data_model, HumanPromptBase):
            return WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE

        # Unknown model: raised so the handler below logs it with a traceback.
        raise ValueError("Data type not found")

    except ValueError as err:
        logger.error("Error type not found converting data to validated websocket message content: %s",
                     str(err),
                     exc_info=True)
        return WebSocketMessageType.ERROR_MESSAGE
|
|
227
|
+
|
|
228
|
+
async def get_intermediate_step_parent_id(self, data_model: AIQResponseIntermediateStep) -> str:
    """
    Retrieves intermediate step parent_id from AIQResponseIntermediateStep instance.

    :param data_model: AIQResponseIntermediateStep Data Model instance.
    :return: The step's parent_id, or "root" when parent_id is None or empty.
    """
    # Any falsy parent_id (None or "") falls back to "root".
    return data_model.parent_id or "root"
|
|
236
|
+
|
|
237
|
+
async def create_system_response_token_message(  # pylint: disable=R0917:too-many-positional-arguments
    self,
    message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
                          WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
    message_id: str | None = None,
    thread_id: str = "default",
    parent_id: str = "default",
    content: SystemResponseContent
    | Error = SystemResponseContent(),
    status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
    timestamp: str | None = None
) -> WebSocketSystemResponseTokenMessage | None:
    """
    Creates a system response token message with default values.

    :param message_type: Type of WebSocket message.
    :param message_id: Unique identifier for the message. When None (the default), a new UUID is generated for
                       this call.
    :param thread_id: ID of the thread the message belongs to (default: "default").
    :param parent_id: ID of the user message that spawned child messages.
    :param content: Message content.
    :param status: Status of the message (default: IN_PROGRESS).
    :param timestamp: Timestamp of the message. When None (the default), the current UTC time at call time is used.
    :return: A WebSocketSystemResponseTokenMessage instance, or None if the message could not be built (the error
             is logged).
    """
    # Resolve defaults per call. Default expressions in a signature are evaluated once, when the function is
    # defined, so generating them there gave every message the same UUID and the module-import timestamp.
    if message_id is None:
        message_id = str(uuid.uuid4())
    if timestamp is None:
        timestamp = str(datetime.datetime.now(datetime.timezone.utc))

    try:
        return WebSocketSystemResponseTokenMessage(type=message_type,
                                                   id=message_id,
                                                   thread_id=thread_id,
                                                   parent_id=parent_id,
                                                   content=content,
                                                   status=status,
                                                   timestamp=timestamp)

    except Exception as e:
        logger.error("Error creating system response token message: %s", str(e), exc_info=True)
        return None
|
|
273
|
+
|
|
274
|
+
async def create_system_intermediate_step_message(  # pylint: disable=R0917:too-many-positional-arguments
    self,
    message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
        WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
    message_id: str | None = None,
    thread_id: str = "default",
    parent_id: str = "default",
    content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
    status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
    timestamp: str | None = None
) -> WebSocketSystemIntermediateStepMessage | None:
    """
    Creates a system intermediate step message with default values.

    :param message_type: Type of WebSocket message.
    :param message_id: Unique identifier for the message. When None (the default), a new UUID is generated for
                       this call.
    :param thread_id: ID of the thread the message belongs to (default: "default").
    :param parent_id: ID of the user message that spawned child messages.
    :param content: Message content
    :param status: Status of the message (default: IN_PROGRESS).
    :param timestamp: Timestamp of the message. When None (the default), the current UTC time at call time is used.
    :return: A WebSocketSystemIntermediateStepMessage instance, or None if the message could not be built (the
             error is logged).
    """
    # Resolve defaults per call. Default expressions in a signature are evaluated once, when the function is
    # defined, so generating them there gave every message the same UUID and the module-import timestamp.
    if message_id is None:
        message_id = str(uuid.uuid4())
    if timestamp is None:
        timestamp = str(datetime.datetime.now(datetime.timezone.utc))

    try:
        return WebSocketSystemIntermediateStepMessage(type=message_type,
                                                      id=message_id,
                                                      thread_id=thread_id,
                                                      parent_id=parent_id,
                                                      content=content,
                                                      status=status,
                                                      timestamp=timestamp)

    except Exception as e:
        logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
        return None
|
|
309
|
+
|
|
310
|
+
async def create_system_interaction_message(  # pylint: disable=R0917:too-many-positional-arguments
    self,
    *,
    message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
        WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE),
    message_id: str | None = None,
    thread_id: str = "default",
    parent_id: str = "default",
    content: HumanPrompt,
    status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
    timestamp: str | None = None
) -> WebSocketSystemInteractionMessage | None:  # noqa: E125 continuation line with same indent as next logical line
    """
    Creates a system interaction message with default values.

    :param message_type: Type of WebSocket message.
    :param message_id: Unique identifier for the message. When None (the default), a new UUID is generated for
                       this call.
    :param thread_id: ID of the thread the message belongs to (default: "default").
    :param parent_id: ID of the user message that spawned child messages.
    :param content: Message content
    :param status: Status of the message (default: IN_PROGRESS).
    :param timestamp: Timestamp of the message. When None (the default), the current UTC time at call time is used.
    :return: A WebSocketSystemInteractionMessage instance, or None if the message could not be built (the error is
             logged).
    """
    # Resolve defaults per call. Default expressions in a signature are evaluated once, when the function is
    # defined, so generating them there gave every message the same UUID and the module-import timestamp.
    if message_id is None:
        message_id = str(uuid.uuid4())
    if timestamp is None:
        timestamp = str(datetime.datetime.now(datetime.timezone.utc))

    try:
        return WebSocketSystemInteractionMessage(type=message_type,
                                                 id=message_id,
                                                 thread_id=thread_id,
                                                 parent_id=parent_id,
                                                 content=content,
                                                 status=status,
                                                 timestamp=timestamp)

    except Exception as e:
        logger.error("Error creating system interaction message: %s", str(e), exc_info=True)
        return None
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from aiq.cli.register_workflow import register_front_end
|
|
17
|
+
from aiq.data_models.config import AIQConfig
|
|
18
|
+
from aiq.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_front_end(config_type=FastApiFrontEndConfig)
async def register_fastapi_front_end(config: FastApiFrontEndConfig, full_config: AIQConfig):
    """
    Front-end factory registered for ``FastApiFrontEndConfig``.

    Yields a single ``FastApiFrontEndPlugin`` built from the complete workflow configuration. ``config`` itself is
    not read here (presumably it is required by the registration signature).
    """
    # Deferred import: the plugin module is only imported when this factory actually runs.
    from aiq.front_ends.fastapi.fastapi_front_end_plugin import FastApiFrontEndPlugin

    plugin = FastApiFrontEndPlugin(full_config=full_config)
    yield plugin
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import typing
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
|
|
20
|
+
from aiq.data_models.api_server import AIQResponseIntermediateStep
|
|
21
|
+
from aiq.data_models.api_server import AIQResponsePayloadOutput
|
|
22
|
+
from aiq.data_models.api_server import AIQResponseSerializable
|
|
23
|
+
from aiq.data_models.step_adaptor import StepAdaptorConfig
|
|
24
|
+
from aiq.front_ends.fastapi.intermediate_steps_subscriber import pull_intermediate
|
|
25
|
+
from aiq.front_ends.fastapi.step_adaptor import StepAdaptor
|
|
26
|
+
from aiq.runtime.session import AIQSessionManager
|
|
27
|
+
from aiq.utils.producer_consumer_queue import AsyncIOProducerConsumerQueue
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
async def generate_streaming_response_as_str(payload: typing.Any,
                                             *,
                                             session_manager: AIQSessionManager,
                                             streaming: bool,
                                             step_adaptor: StepAdaptor | None = None,
                                             result_type: type | None = None,
                                             output_type: type | None = None) -> AsyncGenerator[str]:
    """
    Run the workflow for ``payload`` and yield every streamed item serialized with ``get_stream_data()``.

    Thin wrapper around ``generate_streaming_response``.

    :param payload: Input passed to ``session_manager.run``.
    :param session_manager: Session manager used to run the workflow.
    :param streaming: Stream workflow output chunks when the workflow supports streaming output.
    :param step_adaptor: Adaptor applied to intermediate steps. When omitted, a fresh
        ``StepAdaptor(StepAdaptorConfig())`` is created for this call.
    :param result_type: Type requested from the runner when fetching a single result.
    :param output_type: Type the workflow output is converted to.
    :return: Async generator of serialized stream items.
    :raises ValueError: If the underlying stream yields an item that is not an ``AIQResponseSerializable``.
    """
    # Create the adaptor per call rather than sharing one instance built at import time as a default argument,
    # so any state an adaptor keeps is not shared across requests.
    if step_adaptor is None:
        step_adaptor = StepAdaptor(StepAdaptorConfig())

    async for item in generate_streaming_response(payload,
                                                  session_manager=session_manager,
                                                  streaming=streaming,
                                                  step_adaptor=step_adaptor,
                                                  result_type=result_type,
                                                  output_type=output_type):

        if not isinstance(item, AIQResponseSerializable):
            raise ValueError("Unexpected item type in stream. Expected AIQResponseSerializable, got: " +
                             str(type(item)))

        yield item.get_stream_data()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
async def generate_streaming_response(payload: typing.Any,
                                      *,
                                      session_manager: AIQSessionManager,
                                      streaming: bool,
                                      step_adaptor: StepAdaptor | None = None,
                                      result_type: type | None = None,
                                      output_type: type | None = None) -> AsyncGenerator[AIQResponseSerializable]:
    """
    Run the workflow for ``payload`` and yield intermediate steps and workflow output as they arrive.

    Intermediate steps (passed through ``step_adaptor``) and the workflow output share one producer/consumer
    queue. Items that are already ``AIQResponseSerializable`` are yielded unchanged; anything else is wrapped in
    ``AIQResponsePayloadOutput``.

    :param payload: Input passed to ``session_manager.run``.
    :param session_manager: Session manager used to run the workflow.
    :param streaming: When True and the workflow has streaming output, chunks come from ``runner.result_stream``;
        otherwise a single result is fetched with ``runner.result`` and converted with ``runner.convert``.
    :param step_adaptor: Adaptor applied to intermediate steps. When omitted, a fresh
        ``StepAdaptor(StepAdaptorConfig())`` is created for this call.
    :param result_type: Type requested from ``runner.result`` in non-streaming mode.
    :param output_type: Type the workflow output is converted to.
    :return: Async generator of ``AIQResponseSerializable`` items.
    :raises Exception: An exception raised while producing the workflow output is re-raised to the caller after
        the items queued before the failure have been yielded.
    """
    # Create the adaptor per call rather than sharing one instance built at import time as a default argument,
    # so any state an adaptor keeps is not shared across requests.
    if step_adaptor is None:
        step_adaptor = StepAdaptor(StepAdaptorConfig())

    async with session_manager.run(payload) as runner:

        q: AsyncIOProducerConsumerQueue[AIQResponseSerializable] = AsyncIOProducerConsumerQueue()

        # Start the intermediate stream; the returned event is awaited below before the queue is closed.
        intermediate_complete = await pull_intermediate(q, step_adaptor)

        async def pull_result():
            try:
                if session_manager.workflow.has_streaming_output and streaming:
                    async for chunk in runner.result_stream(to_type=output_type):
                        await q.put(chunk)
                else:
                    result = await runner.result(to_type=result_type)
                    await q.put(runner.convert(result, output_type))

                # Wait until the intermediate subscription is done so its last steps reach the queue.
                await intermediate_complete.wait()
            finally:
                # Close the queue even if the workflow failed; otherwise the consumer loop below would wait
                # forever for items that will never arrive.
                await q.close()

        # Keep a reference to the task: it must not be garbage collected mid-run, and its exception is
        # re-raised below instead of being dropped as "never retrieved".
        result_task = asyncio.create_task(pull_result())

        try:
            async for item in q:

                if isinstance(item, AIQResponseSerializable):
                    yield item
                else:
                    yield AIQResponsePayloadOutput(payload=item)

            # The queue only ends once pull_result has closed it, so this returns promptly and surfaces any
            # exception raised while producing the workflow output.
            await result_task
        finally:
            await q.close()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
async def generate_single_response(
    payload: typing.Any,
    session_manager: AIQSessionManager,
    result_type: type | None = None,
) -> typing.Any:
    """
    Run the workflow once for ``payload`` and return its single, non-streamed result.

    :param payload: Input passed to ``session_manager.run``.
    :param session_manager: Session manager whose workflow must provide a single output.
    :param result_type: Type requested from ``runner.result``.
    :return: The workflow result.
    :raises ValueError: If the workflow does not provide a single output.
    """
    supports_single_output = session_manager.workflow.has_single_output
    if not supports_single_output:
        raise ValueError("Cannot get a single output value for streaming workflows")

    async with session_manager.run(payload) as runner:
        single_result = await runner.result(to_type=result_type)
        return single_result
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
async def generate_streaming_response_full(payload: typing.Any,
                                           *,
                                           session_manager: AIQSessionManager,
                                           streaming: bool,
                                           result_type: type | None = None,
                                           output_type: type | None = None,
                                           filter_steps: str | None = None) -> AsyncGenerator[AIQResponseSerializable]:
    """
    Similar to generate_streaming_response but provides raw AIQResponseIntermediateStep objects
    without any step adaptor translations.

    :param payload: Input passed to ``session_manager.run``.
    :param session_manager: Session manager used to run the workflow.
    :param streaming: When True and the workflow has streaming output, chunks come from ``runner.result_stream``;
        otherwise a single result is fetched with ``runner.result`` and converted with ``runner.convert``.
    :param result_type: Type requested from ``runner.result`` in non-streaming mode.
    :param output_type: Type the workflow output is converted to.
    :param filter_steps: Optional comma-separated list of intermediate step types to emit (surrounding whitespace
        around each entry is ignored). The value "none" (any case) suppresses all intermediate steps. When omitted,
        every intermediate step is emitted.
    :return: Async generator yielding ``AIQResponseIntermediateStep`` items (filtered as above) and
        ``AIQResponsePayloadOutput`` items wrapping everything else.
    :raises Exception: An exception raised while producing the workflow output is re-raised to the caller after
        the items queued before the failure have been yielded.
    """
    # None means "no filtering"; an empty set means "emit no intermediate steps".
    allowed_types: set[str] | None = None
    if filter_steps:
        if filter_steps.strip().lower() == "none":
            allowed_types = set()
        else:
            # Strip each entry so "LLM_START, TOOL_END" matches the same types as "LLM_START,TOOL_END".
            allowed_types = {step_type.strip() for step_type in filter_steps.split(',') if step_type.strip()}

    async with session_manager.run(payload) as runner:
        q: AsyncIOProducerConsumerQueue[AIQResponseSerializable] = AsyncIOProducerConsumerQueue()

        # Start the intermediate stream without step adaptor
        intermediate_complete = await pull_intermediate(q, None)

        async def pull_result():
            try:
                if session_manager.workflow.has_streaming_output and streaming:
                    async for chunk in runner.result_stream(to_type=output_type):
                        await q.put(chunk)
                else:
                    result = await runner.result(to_type=result_type)
                    await q.put(runner.convert(result, output_type))

                # Wait until the intermediate subscription is done so its last steps reach the queue.
                await intermediate_complete.wait()
            finally:
                # Close the queue even if the workflow failed; otherwise the consumer loop below would wait
                # forever for items that will never arrive.
                await q.close()

        # Keep a reference to the task: it must not be garbage collected mid-run, and its exception is
        # re-raised below instead of being dropped as "never retrieved".
        result_task = asyncio.create_task(pull_result())

        try:
            async for item in q:
                if isinstance(item, AIQResponseIntermediateStep):
                    if allowed_types is None or item.type in allowed_types:
                        yield item
                else:
                    yield AIQResponsePayloadOutput(payload=item)

            # The queue only ends once pull_result has closed it, so this returns promptly and surfaces any
            # exception raised while producing the workflow output.
            await result_task
        finally:
            await q.close()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
async def generate_streaming_response_full_as_str(payload: typing.Any,
                                                  *,
                                                  session_manager: AIQSessionManager,
                                                  streaming: bool,
                                                  result_type: type | None = None,
                                                  output_type: type | None = None,
                                                  filter_steps: str | None = None) -> AsyncGenerator[str]:
    """
    Similar to generate_streaming_response_full but converts each item to its string stream format.

    :param payload: Input passed to ``session_manager.run``.
    :param session_manager: Session manager used to run the workflow.
    :param streaming: Stream workflow output chunks when the workflow supports streaming output.
    :param result_type: Type requested from the runner when fetching a single result.
    :param output_type: Type the workflow output is converted to.
    :param filter_steps: Intermediate step filter forwarded to ``generate_streaming_response_full``.
    :return: Async generator of serialized stream items.
    :raises ValueError: If an item is neither an ``AIQResponseIntermediateStep`` nor an
        ``AIQResponsePayloadOutput``.
    """
    async for item in generate_streaming_response_full(payload,
                                                       session_manager=session_manager,
                                                       streaming=streaming,
                                                       result_type=result_type,
                                                       output_type=output_type,
                                                       filter_steps=filter_steps):
        if not isinstance(item, (AIQResponseIntermediateStep, AIQResponsePayloadOutput)):
            raise ValueError("Unexpected item type in stream. Expected AIQResponseIntermediateStep or "
                             "AIQResponsePayloadOutput, got: " + str(type(item)))

        yield item.get_stream_data()
|