nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
import typing
|
|
19
|
+
import uuid
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
from fastapi import WebSocket
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import ValidationError
|
|
25
|
+
from starlette.websockets import WebSocketDisconnect
|
|
26
|
+
|
|
27
|
+
from nat.authentication.interfaces import FlowHandlerBase
|
|
28
|
+
from nat.data_models.api_server import ChatResponse
|
|
29
|
+
from nat.data_models.api_server import ChatResponseChunk
|
|
30
|
+
from nat.data_models.api_server import Error
|
|
31
|
+
from nat.data_models.api_server import ErrorTypes
|
|
32
|
+
from nat.data_models.api_server import ResponsePayloadOutput
|
|
33
|
+
from nat.data_models.api_server import ResponseSerializable
|
|
34
|
+
from nat.data_models.api_server import SystemResponseContent
|
|
35
|
+
from nat.data_models.api_server import TextContent
|
|
36
|
+
from nat.data_models.api_server import WebSocketMessageStatus
|
|
37
|
+
from nat.data_models.api_server import WebSocketMessageType
|
|
38
|
+
from nat.data_models.api_server import WebSocketSystemInteractionMessage
|
|
39
|
+
from nat.data_models.api_server import WebSocketSystemIntermediateStepMessage
|
|
40
|
+
from nat.data_models.api_server import WebSocketSystemResponseTokenMessage
|
|
41
|
+
from nat.data_models.api_server import WebSocketUserInteractionResponseMessage
|
|
42
|
+
from nat.data_models.api_server import WebSocketUserMessage
|
|
43
|
+
from nat.data_models.api_server import WorkflowSchemaType
|
|
44
|
+
from nat.data_models.interactive import HumanPromptNotification
|
|
45
|
+
from nat.data_models.interactive import HumanResponse
|
|
46
|
+
from nat.data_models.interactive import HumanResponseNotification
|
|
47
|
+
from nat.data_models.interactive import InteractionPrompt
|
|
48
|
+
from nat.front_ends.fastapi.message_validator import MessageValidator
|
|
49
|
+
from nat.front_ends.fastapi.response_helpers import generate_streaming_response
|
|
50
|
+
from nat.front_ends.fastapi.step_adaptor import StepAdaptor
|
|
51
|
+
from nat.runtime.session import SessionManager
|
|
52
|
+
|
|
53
|
+
logger = logging.getLogger(__name__)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class WebSocketMessageHandler:
|
|
57
|
+
|
|
58
|
+
def __init__(self, socket: WebSocket, session_manager: SessionManager, step_adaptor: StepAdaptor):
    """
    Initialize the message handler for a single websocket connection.

    :param socket: The websocket this handler accepts, reads messages from and sends messages to.
    :param session_manager: Session manager whose ``workflow`` supplies the generate/stream output schemas.
    :param step_adaptor: Step adaptor, stored as ``self._step_adaptor``.
    """
    self._socket: WebSocket = socket
    self._session_manager: SessionManager = session_manager
    self._step_adaptor: StepAdaptor = step_adaptor

    self._message_validator: MessageValidator = MessageValidator()
    # Task running the current workflow; None when no workflow has been started (or it has finished).
    self._running_workflow_task: asyncio.Task | None = None
    # Parent id used for outgoing response-token messages; overwritten with each incoming user message id.
    self._message_parent_id: str = "default_id"
    self._conversation_id: str | None = None
    # Schema type of the most recent user message; None until the first user message is processed.
    self._workflow_schema_type: str | None = None
    # Future resolved by `run` when a user-interaction response arrives.
    # NOTE(review): annotated as HumanResponse, but `run` resolves it with the TextContent returned by
    # `process_user_message_content` -- confirm which type consumers expect.
    self._user_interaction_response: asyncio.Future[HumanResponse] | None = None

    self._flow_handler: FlowHandlerBase | None = None

    # Output model used for each workflow schema type when a workflow is launched.
    self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
        WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
        WorkflowSchemaType.CHAT: ChatResponse,
        WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
        WorkflowSchemaType.GENERATE_STREAM: self._session_manager.workflow.streaming_output_schema,
    }
|
|
78
|
+
|
|
79
|
+
def set_flow_handler(self, flow_handler: FlowHandlerBase) -> None:
    """
    Attach an authentication flow handler to this connection.

    :param flow_handler: Handler to store as ``self._flow_handler``; replaces any previously set handler.
    """
    self._flow_handler = flow_handler
|
|
81
|
+
|
|
82
|
+
async def __aenter__(self) -> "WebSocketMessageHandler":
    """
    Accept the websocket connection when entering the ``async with`` block.

    :return: This handler instance.
    """
    connection = self._socket
    await connection.accept()
    return self
|
|
86
|
+
|
|
87
|
+
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
    """
    Leave the ``async with`` block.

    Currently performs no cleanup. Because it returns None, any exception raised inside the
    block is propagated rather than suppressed.
    """
    # TODO: Handle the exit # pylint: disable=fixme
    return None
|
|
91
|
+
|
|
92
|
+
async def run(self) -> None:
    """
    Processes received messages from websocket and routes them appropriately.

    Loops until the websocket disconnects or the task is cancelled:

    * ``WebSocketUserMessage`` starts a workflow via :meth:`process_workflow_request`.
    * System response-token / intermediate-step / interaction messages need no handling here.
    * ``WebSocketUserInteractionResponseMessage`` resolves the pending interaction future, if one
      is waiting. Otherwise the response is logged and dropped instead of crashing the loop.
    """
    while True:

        try:

            message: dict[str, Any] = await self._socket.receive_json()

            validated_message: BaseModel = await self._message_validator.validate_message(message)

            # Received a request to start a workflow
            if isinstance(validated_message, WebSocketUserMessage):
                await self.process_workflow_request(validated_message)

            elif isinstance(validated_message,
                            (WebSocketSystemResponseTokenMessage,
                             WebSocketSystemIntermediateStepMessage,
                             WebSocketSystemInteractionMessage)):
                # These messages are already handled by self.create_websocket_message(data_model=value, …)
                # No further processing is needed here.
                pass

            elif isinstance(validated_message, WebSocketUserInteractionResponseMessage):
                user_content = await self.process_user_message_content(validated_message)
                pending_response = self._user_interaction_response
                # The future is None until an interaction is requested, and may already be resolved;
                # calling set_result in either case would raise and terminate this receive loop.
                if pending_response is None or pending_response.done():
                    logger.warning("Received a user interaction response with no pending interaction; ignoring it.")
                else:
                    pending_response.set_result(user_content)
        except (asyncio.CancelledError, WebSocketDisconnect):
            # TODO: Handle the disconnect # pylint: disable=fixme
            break

    return None
|
|
126
|
+
|
|
127
|
+
async def process_user_message_content(
        self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
    """
    Find the text content of the most recent user message.

    Walks ``user_content.content.messages`` from newest to oldest. For each message whose role is
    ``"user"``, it returns that message's first ``TextContent`` part. User messages without any
    text part are skipped in favour of earlier ones.

    :param user_content: Incoming content data model.
    :return: The first ``TextContent`` of the latest user message that has one, otherwise None.
    """
    newest_first = reversed(user_content.content.messages)
    for candidate in (msg for msg in newest_first if msg.role == "user"):
        text_part = next((part for part in candidate.content if isinstance(part, TextContent)), None)
        if text_part is not None:
            return text_part

    return None
|
|
145
|
+
|
|
146
|
+
async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
    """
    Process user messages and routes them appropriately.

    Records the message id, schema type and conversation id of the incoming message. Then, if no
    workflow is currently running, launches ``_run_workflow`` as a background task. A message
    that arrives while a workflow is still running is logged and ignored. If the message has no
    text content, an ``INVALID_USER_MESSAGE_CONTENT`` error message is sent over the websocket.

    :param user_message_as_validated_type: A WebSocketUserMessage Data Model instance.
    """

    try:
        self._message_parent_id = user_message_as_validated_type.id
        self._workflow_schema_type = user_message_as_validated_type.schema_type
        self._conversation_id = user_message_as_validated_type.conversation_id

        content: BaseModel | None = await self.process_user_message_content(user_message_as_validated_type)

        if content is None:
            raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")

        if isinstance(content, TextContent) and (self._running_workflow_task is None):

            workflow_task = asyncio.create_task(
                self._run_workflow(content.text,
                                   self._conversation_id,
                                   result_type=self._schema_output_mapping[self._workflow_schema_type],
                                   output_type=self._schema_output_mapping[self._workflow_schema_type]))

            def _done_callback(finished_task: asyncio.Task) -> None:
                # Only clear the slot if it still refers to the task that just finished.
                if self._running_workflow_task is finished_task:
                    self._running_workflow_task = None

            # Store the Task itself (add_done_callback returns None). This keeps a strong reference
            # so the task is not garbage-collected, and makes the "already running" guard effective.
            self._running_workflow_task = workflow_task
            workflow_task.add_done_callback(_done_callback)

        elif self._running_workflow_task is not None:
            logger.warning("A workflow is already running on this connection; ignoring user message %s",
                           user_message_as_validated_type.id)

    except ValueError as e:
        logger.error("User message content not found: %s", str(e), exc_info=True)
        await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
                                                             message="User message content could not be found",
                                                             details=str(e)),
                                            message_type=WebSocketMessageType.ERROR_MESSAGE,
                                            status=WebSocketMessageStatus.IN_PROGRESS)
|
|
182
|
+
|
|
183
|
+
async def create_websocket_message(self,
                                   data_model: BaseModel,
                                   message_type: str | None = None,
                                   status: str = WebSocketMessageStatus.IN_PROGRESS) -> None:
    """
    Build a websocket message from a data model and send it over the socket.

    The message type is resolved from ``data_model`` when not given. If the message cannot be
    built, an error response message is sent instead.

    :param data_model: Data model to convert into the message content.
    :param message_type: WebSocket message type; resolved from ``data_model`` when None.
    :param status: Status attached to the outgoing message.
    """
    message: BaseModel | None = None
    try:
        if message_type is None:
            message_type = await self._message_validator.resolve_message_type_by_data(data_model)

        message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)

        # Reuse the data model's own id when it has one; otherwise generate a fresh id.
        message_id: str | None = data_model.id if 'id' in type(data_model).model_fields else None
        if message_id is None:
            message_id = str(uuid.uuid4())

        content: BaseModel = await self._message_validator.convert_data_to_message_content(data_model)

        if issubclass(message_schema, WebSocketSystemResponseTokenMessage):
            message = await self._message_validator.create_system_response_token_message(
                message_id=message_id,
                parent_id=self._message_parent_id,
                conversation_id=self._conversation_id,
                content=content,
                status=status)

        elif issubclass(message_schema, WebSocketSystemIntermediateStepMessage):
            message = await self._message_validator.create_system_intermediate_step_message(
                message_id=message_id,
                parent_id=await self._message_validator.get_intermediate_step_parent_id(data_model),
                conversation_id=self._conversation_id,
                content=content,
                status=status)

        elif issubclass(message_schema, WebSocketSystemInteractionMessage):
            message = await self._message_validator.create_system_interaction_message(
                message_id=message_id,
                parent_id=self._message_parent_id,
                conversation_id=self._conversation_id,
                content=content,
                status=status)

        elif isinstance(content, Error):
            # pydantic's ValidationError cannot be constructed from a plain message string;
            # raise ValueError, which the handler below already catches.
            raise ValueError(f"Invalid input data creating websocket message. {data_model.model_dump_json()}")

        elif issubclass(message_schema, Error):
            raise TypeError(f"Invalid message type: {message_type}")

        else:
            raise ValueError(
                f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")

        # The validator's create_* helpers log and return None on failure; report that instead of
        # silently sending nothing.
        if message is None:
            raise ValueError(
                f"Websocket message could not be created for type {message_type}: {data_model.model_dump_json()}")

    except (ValidationError, TypeError, ValueError) as e:
        logger.error("A data validation error occurred creating websocket message: %s", str(e), exc_info=True)
        message = await self._message_validator.create_system_response_token_message(
            message_type=WebSocketMessageType.ERROR_MESSAGE,
            message_id=str(uuid.uuid4()),
            parent_id=self._message_parent_id,
            conversation_id=self._conversation_id,
            content=Error(code=ErrorTypes.UNKNOWN_ERROR, message="default", details=str(e)))

    finally:
        if message is not None:
            await self._socket.send_json(message.model_dump())
|
|
253
|
+
|
|
254
|
+
async def human_interaction_callback(self, prompt: InteractionPrompt) -> HumanResponse:
    """
    Human interaction callback: forward ``prompt`` to the client and wait for its reply.

    The prompt content is sent as a system interaction message. For notification prompts no
    reply is awaited and a ``HumanResponseNotification`` is returned immediately. Otherwise this
    waits on a future exposed as ``self._user_interaction_response`` and converts the received
    reply into the ``HumanResponse`` matching the prompt type.

    :param prompt: Incoming interaction content data model.
    :return: A HumanResponse data model corresponding to the prompt's type.
    """

    # Future that will hold the client's reply.
    # NOTE(review): presumably resolved with a TextContent by the code handling incoming
    # user-interaction messages, since its value is passed to
    # convert_text_content_to_human_response below — confirm against that code.
    human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()

    # Expose the pending future so the incoming-message handling can resolve it.
    self._user_interaction_response = human_response_future

    try:
        await self.create_websocket_message(data_model=prompt.content,
                                            message_type=WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE,
                                            status=WebSocketMessageStatus.IN_PROGRESS)

        if isinstance(prompt.content, HumanPromptNotification):
            # Notifications are one-way; there is no reply to wait for.
            return HumanResponseNotification()

        # Wait for the client's reply, then map it onto the prompt-specific response model.
        reply: TextContent = await human_response_future

        interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
            reply, prompt.content)

        return interaction_response

    finally:
        # This prompt is finished (answered, notified or aborted); stop exposing its future.
        self._user_interaction_response = None
|
|
290
|
+
|
|
291
|
+
async def _run_workflow(self,
                        payload: typing.Any,
                        conversation_id: str | None = None,
                        result_type: type | None = None,
                        output_type: type | None = None) -> None:
    """
    Run the workflow on ``payload`` and stream every output to the client as websocket messages.

    A session is opened with the websocket as the request, this handler's human interaction
    callback, and the flow handler's ``authenticate`` (if a flow handler is set). Streamed values
    that are not ``ResponseSerializable`` are wrapped in ``ResponsePayloadOutput``. A COMPLETE
    response message is always sent when streaming ends, whether or not it succeeded.

    :param payload: Input passed to the workflow.
    :param conversation_id: Conversation the session belongs to.
    :param result_type: Result type forwarded to ``generate_streaming_response``.
    :param output_type: Output type forwarded to ``generate_streaming_response``.
    :raises Exception: Any error raised by the workflow is logged and re-raised.
    """

    try:
        async with self._session_manager.session(
                conversation_id=conversation_id,
                request=self._socket,
                user_input_callback=self.human_interaction_callback,
                user_authentication_callback=(self._flow_handler.authenticate
                                              if self._flow_handler else None)) as session:

            async for value in generate_streaming_response(payload,
                                                           session_manager=session,
                                                           streaming=True,
                                                           step_adaptor=self._step_adaptor,
                                                           result_type=result_type,
                                                           output_type=output_type):

                if not isinstance(value, ResponseSerializable):
                    value = ResponsePayloadOutput(payload=value)

                await self.create_websocket_message(data_model=value, status=WebSocketMessageStatus.IN_PROGRESS)

    except Exception:
        # This coroutine is started as a background task (see process_workflow_request) whose
        # creator does not inspect its result, so log the failure when it happens.
        logger.exception("Error running workflow for conversation: %s", conversation_id)
        raise

    finally:
        await self.create_websocket_message(data_model=SystemResponseContent(),
                                            message_type=WebSocketMessageType.RESPONSE_MESSAGE,
                                            status=WebSocketMessageStatus.COMPLETE)
|
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import datetime
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import uuid
|
|
20
|
+
from typing import Any
|
|
21
|
+
from typing import Literal
|
|
22
|
+
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import ValidationError
|
|
25
|
+
|
|
26
|
+
from nat.data_models.api_server import ChatResponse
|
|
27
|
+
from nat.data_models.api_server import ChatResponseChunk
|
|
28
|
+
from nat.data_models.api_server import Error
|
|
29
|
+
from nat.data_models.api_server import ErrorTypes
|
|
30
|
+
from nat.data_models.api_server import ResponseIntermediateStep
|
|
31
|
+
from nat.data_models.api_server import ResponsePayloadOutput
|
|
32
|
+
from nat.data_models.api_server import SystemIntermediateStepContent
|
|
33
|
+
from nat.data_models.api_server import SystemResponseContent
|
|
34
|
+
from nat.data_models.api_server import TextContent
|
|
35
|
+
from nat.data_models.api_server import WebSocketMessageStatus
|
|
36
|
+
from nat.data_models.api_server import WebSocketMessageType
|
|
37
|
+
from nat.data_models.api_server import WebSocketSystemInteractionMessage
|
|
38
|
+
from nat.data_models.api_server import WebSocketSystemIntermediateStepMessage
|
|
39
|
+
from nat.data_models.api_server import WebSocketSystemResponseTokenMessage
|
|
40
|
+
from nat.data_models.api_server import WebSocketUserInteractionResponseMessage
|
|
41
|
+
from nat.data_models.api_server import WebSocketUserMessage
|
|
42
|
+
from nat.data_models.interactive import BinaryHumanPromptOption
|
|
43
|
+
from nat.data_models.interactive import HumanPrompt
|
|
44
|
+
from nat.data_models.interactive import HumanPromptBase
|
|
45
|
+
from nat.data_models.interactive import HumanPromptBinary
|
|
46
|
+
from nat.data_models.interactive import HumanPromptCheckbox
|
|
47
|
+
from nat.data_models.interactive import HumanPromptDropdown
|
|
48
|
+
from nat.data_models.interactive import HumanPromptRadio
|
|
49
|
+
from nat.data_models.interactive import HumanPromptText
|
|
50
|
+
from nat.data_models.interactive import HumanResponse
|
|
51
|
+
from nat.data_models.interactive import HumanResponseBinary
|
|
52
|
+
from nat.data_models.interactive import HumanResponseCheckbox
|
|
53
|
+
from nat.data_models.interactive import HumanResponseDropdown
|
|
54
|
+
from nat.data_models.interactive import HumanResponseRadio
|
|
55
|
+
from nat.data_models.interactive import HumanResponseText
|
|
56
|
+
from nat.data_models.interactive import MultipleChoiceOption
|
|
57
|
+
|
|
58
|
+
# Module-level logger used by MessageValidator to report validation and conversion failures.
logger = logging.getLogger(__name__)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class MessageValidator:
    """
    Validates incoming WebSocket messages and converts data models into outgoing WebSocket messages.

    Most methods log failures and return a fallback value (an error message, the ``Error`` schema
    or content, a text response, or ``None``) instead of raising.
    """

    def __init__(self):
        # Maps each WebSocket message type to the Pydantic schema used to validate or build it.
        self._message_type_schema_mapping: dict[str, type[BaseModel]] = {
            WebSocketMessageType.USER_MESSAGE: WebSocketUserMessage,
            WebSocketMessageType.RESPONSE_MESSAGE: WebSocketSystemResponseTokenMessage,
            WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE: WebSocketSystemIntermediateStepMessage,
            WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE: WebSocketSystemInteractionMessage,
            WebSocketMessageType.USER_INTERACTION_MESSAGE: WebSocketUserInteractionResponseMessage,
            WebSocketMessageType.ERROR_MESSAGE: Error
        }

        # NOTE(review): not read by any method of this class; kept for compatibility.
        self._message_parent_id: str = "default_id"

    @staticmethod
    def _new_message_id() -> str:
        """Return a fresh random message id (generated per call, not once at definition time)."""
        return str(uuid.uuid4())

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a string (generated per call, not once at definition time)."""
        return str(datetime.datetime.now(datetime.timezone.utc))

    async def validate_message(self, message: dict[str, Any]) -> BaseModel:
        """
        Validates an incoming WebSocket message against its expected schema.
        If validation fails, returns a system response error message.

        :param message: Incoming WebSocket message as a dictionary.
        :return: A validated Pydantic model, or an error response token message on failure.
        """
        try:
            message_type = message.get("type")
            if not message_type:
                raise ValueError(f"Missing message type: {json.dumps(message)}")

            schema: type[BaseModel] = await self.get_message_schema_by_type(message_type)

            # get_message_schema_by_type falls back to Error for unknown message types.
            if issubclass(schema, Error):
                raise TypeError(
                    f"An error was encountered processing an incoming WebSocket message of type: {message_type}")

            validated_message: BaseModel = schema(**message)
            return validated_message

        except (ValidationError, TypeError, ValueError) as e:
            logger.error("A data validation error %s occurred for message: %s", str(e), str(message), exc_info=True)
            return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
                                                                   content=Error(code=ErrorTypes.INVALID_MESSAGE,
                                                                                 message="Error validating message.",
                                                                                 details=str(e)))

    async def get_message_schema_by_type(self, message_type: str) -> type[BaseModel]:
        """
        Retrieves the corresponding Pydantic model schema based on the message type.

        :param message_type: The type of message as a string.
        :return: The matching Pydantic schema class, or ``Error`` if the type is unknown or invalid.
        """
        try:
            schema: type[BaseModel] | None = self._message_type_schema_mapping.get(message_type)

            if schema is None:
                raise ValueError(f"Unknown message type: {message_type}")

            return schema

        except (TypeError, ValueError) as e:
            logger.error("Error retrieving schema for message type '%s': %s", message_type, str(e), exc_info=True)
            return Error

    async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
        """
        Converts a Pydantic data model to a WebSocket message content instance.

        :param data_model: Pydantic Data Model instance.
        :return: A WebSocket Message Content Data Model instance, or an ``Error`` when the data
            model is unsupported or cannot be converted.
        """
        validated_message_content: BaseModel | None = None
        try:
            if isinstance(data_model, ResponsePayloadOutput):
                # Pydantic payloads are serialized as JSON; anything else via str().
                if hasattr(data_model.payload, 'model_dump_json'):
                    text_content: str = data_model.payload.model_dump_json()
                else:
                    text_content = str(data_model.payload)
                validated_message_content = SystemResponseContent(text=text_content)

            elif isinstance(data_model, (ChatResponse, ChatResponseChunk)):
                # An empty choices list would raise IndexError, which neither this method nor its
                # callers' handlers catch; report it as unsupported content instead.
                if not data_model.choices:
                    raise ValueError(f"Chat response contains no choices: {data_model.model_dump_json()}")
                validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)

            elif isinstance(data_model, ResponseIntermediateStep):
                validated_message_content = SystemIntermediateStepContent(name=data_model.name,
                                                                          payload=data_model.payload)
            elif isinstance(data_model, HumanPromptBase):
                validated_message_content = data_model
            elif isinstance(data_model, SystemResponseContent):
                return data_model
            else:
                raise ValueError(
                    f"Input data could not be converted to validated message content: {data_model.model_dump_json()}")

            return validated_message_content

        except ValueError as e:
            logger.error("Input data could not be converted to validated message content: %s", str(e), exc_info=True)
            return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))

    async def convert_text_content_to_human_response(self, text_content: TextContent,
                                                     human_prompt: HumanPromptBase) -> HumanResponse:
        """
        Converts Message Text Content data model to a Human Response Base data model instance.

        :param text_content: Pydantic TextContent Data Model instance.
        :param human_prompt: Pydantic HumanPrompt Data Model instance.
        :return: A Human Response Data Model instance matching the prompt type; on failure, a
            ``HumanResponseText`` carrying the error text.
        """
        human_response: HumanResponse | None = None
        try:
            if isinstance(human_prompt, HumanPromptText):
                human_response = HumanResponseText(text=text_content.text)

            elif isinstance(human_prompt, HumanPromptBinary):
                human_response = HumanResponseBinary(selected_option=BinaryHumanPromptOption(value=text_content.text))

            elif isinstance(human_prompt, HumanPromptRadio):
                human_response = HumanResponseRadio(selected_option=MultipleChoiceOption(value=text_content.text))

            elif isinstance(human_prompt, HumanPromptCheckbox):
                human_response = HumanResponseCheckbox(selected_option=MultipleChoiceOption(value=text_content.text))

            elif isinstance(human_prompt, HumanPromptDropdown):
                human_response = HumanResponseDropdown(selected_option=MultipleChoiceOption(value=text_content.text))
            else:
                raise ValueError("Message content type not found")

            return human_response

        except ValueError as e:
            logger.error("Error human response content not found: %s", str(e), exc_info=True)
            return HumanResponseText(text=str(e))

    async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
        """
        Resolve the WebSocket message type for a data model.

        :param data_model: Pydantic Data Model instance.
        :return: The matching WebSocketMessageType, or ``WebSocketMessageType.ERROR_MESSAGE`` when
            the data model type is not recognized.
        """
        validated_message_type: str = ""
        try:
            if isinstance(data_model, (ResponsePayloadOutput, ChatResponse, ChatResponseChunk)):
                validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE

            elif isinstance(data_model, ResponseIntermediateStep):
                validated_message_type = WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE

            elif isinstance(data_model, HumanPromptBase):
                validated_message_type = WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE
            else:
                raise ValueError("Data type not found")

            return validated_message_type

        except ValueError as e:
            logger.error("Error type not found converting data to validated websocket message content: %s",
                         str(e),
                         exc_info=True)
            return WebSocketMessageType.ERROR_MESSAGE

    async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
        """
        Retrieves intermediate step parent_id from ResponseIntermediateStep instance.

        :param data_model: ResponseIntermediateStep Data Model instance.
        :return: Intermediate step parent_id, or "root" when it is empty/None.
        """
        return data_model.parent_id or "root"

    async def create_system_response_token_message(  # pylint: disable=R0917:too-many-positional-arguments
        self,
        message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
                              WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
        message_id: str | None = None,
        thread_id: str = "default",
        parent_id: str = "default",
        conversation_id: str | None = None,
        content: SystemResponseContent
        | Error | None = None,
        status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
        timestamp: str | None = None
    ) -> WebSocketSystemResponseTokenMessage | None:
        """
        Creates a system response token message with default values.

        Defaults for ``message_id``, ``content`` and ``timestamp`` are created per call (they were
        previously evaluated once at definition time, so every defaulted message shared one id
        and one timestamp).

        :param message_type: Type of WebSocket message.
        :param message_id: Unique identifier for the message (default: freshly generated UUID).
        :param thread_id: ID of the thread the message belongs to (default: "default").
        :param parent_id: ID of the user message that spawned child messages.
        :param conversation_id: ID of the conversation this message belongs to (default: None).
        :param content: Message content (default: an empty SystemResponseContent).
        :param status: Status of the message (default: IN_PROGRESS).
        :param timestamp: Timestamp of the message (default: current UTC time).
        :return: A WebSocketSystemResponseTokenMessage instance, or None if it could not be built.
        """
        try:
            if message_id is None:
                message_id = self._new_message_id()
            if content is None:
                content = SystemResponseContent()
            if timestamp is None:
                timestamp = self._utc_timestamp()

            return WebSocketSystemResponseTokenMessage(type=message_type,
                                                       id=message_id,
                                                       thread_id=thread_id,
                                                       parent_id=parent_id,
                                                       conversation_id=conversation_id,
                                                       content=content,
                                                       status=status,
                                                       timestamp=timestamp)

        except Exception as e:
            logger.error("Error creating system response token message: %s", str(e), exc_info=True)
            return None

    async def create_system_intermediate_step_message(  # pylint: disable=R0917:too-many-positional-arguments
        self,
        message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
            WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
        message_id: str | None = None,
        thread_id: str = "default",
        parent_id: str = "default",
        conversation_id: str | None = None,
        content: SystemIntermediateStepContent | None = None,
        status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
        timestamp: str | None = None
    ) -> WebSocketSystemIntermediateStepMessage | None:
        """
        Creates a system intermediate step message with default values.

        Defaults for ``message_id``, ``content`` and ``timestamp`` are created per call.

        :param message_type: Type of WebSocket message.
        :param message_id: Unique identifier for the message (default: freshly generated UUID).
        :param thread_id: ID of the thread the message belongs to (default: "default").
        :param parent_id: ID of the user message that spawned child messages.
        :param conversation_id: ID of the conversation this message belongs to (default: None).
        :param content: Message content (default: name="default", payload="default").
        :param status: Status of the message (default: IN_PROGRESS).
        :param timestamp: Timestamp of the message (default: current UTC time).
        :return: A WebSocketSystemIntermediateStepMessage instance, or None if it could not be built.
        """
        try:
            if message_id is None:
                message_id = self._new_message_id()
            if content is None:
                content = SystemIntermediateStepContent(name="default", payload="default")
            if timestamp is None:
                timestamp = self._utc_timestamp()

            return WebSocketSystemIntermediateStepMessage(type=message_type,
                                                          id=message_id,
                                                          thread_id=thread_id,
                                                          parent_id=parent_id,
                                                          conversation_id=conversation_id,
                                                          content=content,
                                                          status=status,
                                                          timestamp=timestamp)

        except Exception as e:
            logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
            return None

    async def create_system_interaction_message(  # pylint: disable=R0917:too-many-positional-arguments
        self,
        *,
        message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
            WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE),
        message_id: str | None = None,
        thread_id: str = "default",
        parent_id: str = "default",
        conversation_id: str | None = None,
        content: HumanPrompt,
        status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
        timestamp: str | None = None
    ) -> WebSocketSystemInteractionMessage | None:  # noqa: E125 continuation line with same indent as next logical line
        """
        Creates a system interaction message with default values.

        Defaults for ``message_id`` and ``timestamp`` are created per call.

        :param message_type: Type of WebSocket message.
        :param message_id: Unique identifier for the message (default: freshly generated UUID).
        :param thread_id: ID of the thread the message belongs to (default: "default").
        :param parent_id: ID of the user message that spawned child messages.
        :param conversation_id: ID of the conversation this message belongs to (default: None).
        :param content: Message content (required).
        :param status: Status of the message (default: IN_PROGRESS).
        :param timestamp: Timestamp of the message (default: current UTC time).
        :return: A WebSocketSystemInteractionMessage instance, or None if it could not be built.
        """
        try:
            if message_id is None:
                message_id = self._new_message_id()
            if timestamp is None:
                timestamp = self._utc_timestamp()

            return WebSocketSystemInteractionMessage(type=message_type,
                                                     id=message_id,
                                                     thread_id=thread_id,
                                                     parent_id=parent_id,
                                                     conversation_id=conversation_id,
                                                     content=content,
                                                     status=status,
                                                     timestamp=timestamp)

        except Exception as e:
            logger.error("Error creating system interaction message: %s", str(e), exc_info=True)
            return None
|