nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,716 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
import datetime
|
|
18
|
+
import typing
|
|
19
|
+
import uuid
|
|
20
|
+
from abc import abstractmethod
|
|
21
|
+
from enum import Enum
|
|
22
|
+
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import ConfigDict
|
|
25
|
+
from pydantic import Discriminator
|
|
26
|
+
from pydantic import Field
|
|
27
|
+
from pydantic import HttpUrl
|
|
28
|
+
from pydantic import conlist
|
|
29
|
+
from pydantic import field_serializer
|
|
30
|
+
from pydantic import field_validator
|
|
31
|
+
from pydantic_core.core_schema import ValidationInfo
|
|
32
|
+
|
|
33
|
+
from nat.data_models.interactive import HumanPrompt
|
|
34
|
+
from nat.utils.type_converter import GlobalTypeConverter
|
|
35
|
+
|
|
36
|
+
# Allowed values for an OpenAI-compatible ``finish_reason``; same set as the ``Choice.finish_reason`` literal.
FINISH_REASONS = frozenset(("stop", "length", "tool_calls", "content_filter", "function_call"))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Request(BaseModel):
    """
    Snapshot of the attributes of an incoming HTTP request.

    Every attribute is optional and defaults to ``None``; unknown attributes are rejected (``extra="forbid"``).
    """
    model_config = ConfigDict(extra="forbid")

    method: str | None = Field(
        default=None,
        description="HTTP method used for the request (e.g., GET, POST, PUT, DELETE).",
    )
    url_path: str | None = Field(
        default=None,
        description="URL request path.",
    )
    url_port: int | None = Field(
        default=None,
        description="URL request port number.",
    )
    url_scheme: str | None = Field(
        default=None,
        description="URL scheme indicating the protocol (e.g., http, https).",
    )
    # Left as ``Any`` so whatever header/query container the caller supplies is stored as-is.
    headers: typing.Any | None = Field(
        default=None,
        description="HTTP headers associated with the request.",
    )
    query_params: typing.Any | None = Field(
        default=None,
        description="Query parameters included in the request URL.",
    )
    path_params: dict[str, str] | None = Field(
        default=None,
        description="Path parameters extracted from the request URL.",
    )
    client_host: str | None = Field(
        default=None,
        description="Client host address from which the request originated.",
    )
    client_port: int | None = Field(
        default=None,
        description="Client port number from which the request originated.",
    )
    cookies: dict[str, str] | None = Field(
        default=None,
        description="Cookies sent with the request, stored in a dictionary-like object.",
    )
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ChatContentType(str, Enum):
    """
    String enum of the ``type`` discriminator values used by chat content parts.

    Members compare equal to their plain string values because the enum derives from ``str``.
    """

    TEXT = "text"  # plain text part
    IMAGE_URL = "image_url"  # image referenced by URL
    INPUT_AUDIO = "input_audio"  # audio input part
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class InputAudio(BaseModel):
    """Audio payload of an ``input_audio`` content part."""

    # Both fields fall back to the placeholder string "default" when omitted.
    data: str = Field(default="default")
    format: str = Field(default="default")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class AudioContent(BaseModel):
    """Chat content part that carries audio input; its ``type`` is fixed to ``input_audio``."""

    model_config = ConfigDict(extra="forbid")

    # Literal discriminator used by the ``UserContent`` union.
    type: typing.Literal[ChatContentType.INPUT_AUDIO] = Field(default=ChatContentType.INPUT_AUDIO)
    input_audio: InputAudio = Field(default=InputAudio())
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ImageUrl(BaseModel):
    """URL wrapper for an ``image_url`` content part; defaults to a placeholder address."""

    url: HttpUrl = Field(default=HttpUrl(url="http://default.com"))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class ImageContent(BaseModel):
    """Chat content part that references an image; its ``type`` is fixed to ``image_url``."""

    model_config = ConfigDict(extra="forbid")

    # Literal discriminator used by the ``UserContent`` union.
    type: typing.Literal[ChatContentType.IMAGE_URL] = Field(default=ChatContentType.IMAGE_URL)
    image_url: ImageUrl = Field(default=ImageUrl())
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TextContent(BaseModel):
    """Chat content part holding plain text; its ``type`` is fixed to ``text``."""

    model_config = ConfigDict(extra="forbid")

    # Literal discriminator used by the ``UserContent`` union.
    type: typing.Literal[ChatContentType.TEXT] = Field(default=ChatContentType.TEXT)
    text: str = Field(default="default")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Security(BaseModel):
    """Model with ``api_key`` and ``token`` string fields; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Both values default to the placeholder string "default".
    api_key: str = Field(default="default")
    token: str = Field(default="default")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# Discriminated union of the content parts a user message may contain; the ``type`` field selects the model.
UserContent = typing.Annotated[typing.Union[TextContent, ImageContent, AudioContent], Discriminator("type")]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class Message(BaseModel):
    """One chat message: a role plus either plain text or a list of typed content parts."""

    # ``Field(...)`` marks both attributes as required (no default).
    content: str | list[UserContent] = Field(...)
    role: str = Field(...)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class ChatRequest(BaseModel):
    """
    ChatRequest is a data model that represents a request to the NAT chat API.
    Fully compatible with OpenAI Chat Completions API specification.

    Only ``messages`` is required, and it must contain at least one message; every other field is optional.
    Fields not declared here are accepted and kept on the model (``extra="allow"``).
    """

    # Required fields
    # The minimum-length constraint is attached through ``Field`` metadata. The previous form,
    # ``Annotated[list[Message], conlist(Message, min_length=1)]``, nested an ``Annotated`` alias as
    # metadata. Pydantic does not apply that as a constraint, so an empty ``messages`` list was accepted.
    messages: typing.Annotated[list[Message], Field(min_length=1)]

    # Optional fields (OpenAI Chat Completions API compatible)
    model: str | None = Field(default=None, description="name of the model to use")
    frequency_penalty: float | None = Field(default=0.0,
                                            description="Penalty for new tokens based on frequency in text")
    logit_bias: dict[str, float] | None = Field(default=None,
                                                description="Modify likelihood of specified tokens appearing")
    logprobs: bool | None = Field(default=None, description="Whether to return log probabilities")
    top_logprobs: int | None = Field(default=None, description="Number of most likely tokens to return")
    max_tokens: int | None = Field(default=None, description="Maximum number of tokens to generate")
    n: int | None = Field(default=1, description="Number of chat completion choices to generate")
    presence_penalty: float | None = Field(default=0.0, description="Penalty for new tokens based on presence in text")
    response_format: dict[str, typing.Any] | None = Field(default=None, description="Response format specification")
    seed: int | None = Field(default=None, description="Random seed for deterministic sampling")
    service_tier: typing.Literal["auto", "default"] | None = Field(default=None,
                                                                  description="Service tier for the request")
    stream: bool | None = Field(default=False, description="Whether to stream partial message deltas")
    stream_options: dict[str, typing.Any] | None = Field(default=None, description="Options for streaming")
    temperature: float | None = Field(default=1.0, description="Sampling temperature between 0 and 2")
    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
    tools: list[dict[str, typing.Any]] | None = Field(default=None, description="List of tools the model may call")
    tool_choice: str | dict[str, typing.Any] | None = Field(default=None, description="Controls which tool is called")
    parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
    user: str | None = Field(default=None, description="Unique identifier representing end-user")

    model_config = ConfigDict(extra="allow",
                              json_schema_extra={
                                  "example": {
                                      "model": "nvidia/nemotron",
                                      "messages": [{
                                          "role": "user", "content": "who are you?"
                                      }],
                                      "temperature": 0.7,
                                      "stream": False
                                  }
                              })

    @staticmethod
    def from_string(data: str,
                    *,
                    model: str | None = None,
                    temperature: float | None = None,
                    max_tokens: int | None = None,
                    top_p: float | None = None) -> "ChatRequest":
        """
        Build a request containing a single ``user`` message whose content is the plain string ``data``.

        Args:
            data: Text of the user message.
            model: Value stored in ``model``.
            temperature: Value stored in ``temperature``. It is always passed explicitly, so the default
                ``None`` is stored as ``None`` rather than the field default of ``1.0``.
            max_tokens: Value stored in ``max_tokens``.
            top_p: Value stored in ``top_p``.

        Returns:
            The constructed ``ChatRequest``.
        """

        return ChatRequest(messages=[Message(content=data, role="user")],
                           model=model,
                           temperature=temperature,
                           max_tokens=max_tokens,
                           top_p=top_p)

    @staticmethod
    def from_content(content: list[UserContent],
                     *,
                     model: str | None = None,
                     temperature: float | None = None,
                     max_tokens: int | None = None,
                     top_p: float | None = None) -> "ChatRequest":
        """
        Build a request containing a single ``user`` message made of the given typed content parts.

        Args:
            content: Content parts (text, image, or audio) of the user message.
            model: Value stored in ``model``.
            temperature: Value stored in ``temperature``. It is always passed explicitly, so the default
                ``None`` is stored as ``None`` rather than the field default of ``1.0``.
            max_tokens: Value stored in ``max_tokens``.
            top_p: Value stored in ``top_p``.

        Returns:
            The constructed ``ChatRequest``.
        """

        return ChatRequest(messages=[Message(content=content, role="user")],
                           model=model,
                           temperature=temperature,
                           max_tokens=max_tokens,
                           top_p=top_p)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class ChoiceMessage(BaseModel):
    """Complete message attached to a completion choice; both fields may be absent."""

    content: str | None = Field(default=None)
    role: str | None = Field(default=None)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class ChoiceDelta(BaseModel):
    """Incremental message fragment used in streamed (OpenAI-compatible) responses."""

    # Either field may be omitted in any given chunk.
    content: str | None = Field(default=None)
    role: str | None = Field(default=None)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class Choice(BaseModel):
    """
    One completion choice.

    ``message`` holds a complete message and ``delta`` a streaming increment. ``index`` is required, and
    undeclared fields are kept (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    message: ChoiceMessage | None = Field(default=None)
    delta: ChoiceDelta | None = Field(default=None)
    # Same value set as the module-level FINISH_REASONS constant.
    finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = Field(
        default=None)
    index: int = Field(...)
    # NOTE: a ``logprobs: ChoiceLogprobs | None`` field is not modelled yet.
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class Usage(BaseModel):
    """Token counts reported for a completion; all three counts are required integers."""

    prompt_tokens: int = Field(...)
    completion_tokens: int = Field(...)
    total_tokens: int = Field(...)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ResponseSerializable(abc.ABC):
    """
    Interface for objects that can render themselves as a chunk of the NAT Toolkit chat streaming output.

    Concrete subclasses must implement ``get_stream_data``; the class itself cannot be instantiated.
    """

    @abstractmethod
    def get_stream_data(self) -> str:
        """Return the fully formatted text this object contributes to the stream."""
        ...
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class ResponseBaseModelOutput(BaseModel, ResponseSerializable):
    """Base for output payloads; streams the model's JSON in an SSE-style ``data: <json>`` frame."""

    def get_stream_data(self) -> str:
        payload = self.model_dump_json()
        return "data: " + payload + "\n\n"
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class ResponseBaseModelIntermediate(BaseModel, ResponseSerializable):
    """Pydantic model streamed as an ``intermediate_data:`` line containing its JSON dump."""

    def get_stream_data(self) -> str:
        """Return ``intermediate_data: <model JSON>`` followed by a blank line."""
        return "intermediate_data: " + self.model_dump_json() + "\n\n"
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class ChatResponse(ResponseBaseModelOutput):
    """
    ChatResponse is a data model that represents a response from the NAT chat API.
    Fully compatible with OpenAI Chat Completions API specification.
    """

    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")
    id: str
    object: str = "chat.completion"
    model: str = ""
    created: datetime.datetime
    choices: list[Choice]
    usage: Usage | None = None
    system_fingerprint: str | None = None
    service_tier: typing.Literal["scale", "default"] | None = None

    @field_serializer("created")
    def serialize_created(self, created: datetime.datetime) -> int:
        """Emit ``created`` as integer Unix seconds, as OpenAI clients expect."""
        seconds = created.timestamp()
        return int(seconds)

    @staticmethod
    def from_string(data: str,
                    *,
                    id_: str | None = None,
                    object_: str | None = None,
                    model: str | None = None,
                    created: datetime.datetime | None = None,
                    usage: Usage | None = None) -> "ChatResponse":
        """Wrap ``data`` in a single-choice, finished ChatResponse.

        Args:
            data: Text placed in ``choices[0].message.content``.
            id_: Response id; a random UUID4 string when omitted.
            object_: Object tag; ``"chat.completion"`` when omitted.
            model: Model name; empty string when omitted.
            created: Creation time; the current UTC time when omitted.
            usage: Optional token usage to attach.
        """
        only_choice = Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")
        return ChatResponse(id=str(uuid.uuid4()) if id_ is None else id_,
                            object="chat.completion" if object_ is None else object_,
                            model="" if model is None else model,
                            created=datetime.datetime.now(datetime.timezone.utc) if created is None else created,
                            choices=[only_choice],
                            usage=usage)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class ChatResponseChunk(ResponseBaseModelOutput):
    """
    ChatResponseChunk is a data model that represents a response chunk from the NAT chat streaming API.
    Fully compatible with OpenAI Chat Completions API specification.
    """

    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    id: str
    choices: list[Choice]
    created: datetime.datetime
    model: str = ""
    object: str = "chat.completion.chunk"
    system_fingerprint: str | None = None
    service_tier: typing.Literal["scale", "default"] | None = None
    usage: Usage | None = None

    @field_serializer("created")
    def serialize_created(self, created: datetime.datetime) -> int:
        """Emit ``created`` as integer Unix seconds, as OpenAI clients expect."""
        seconds = created.timestamp()
        return int(seconds)

    @staticmethod
    def from_string(data: str,
                    *,
                    id_: str | None = None,
                    created: datetime.datetime | None = None,
                    model: str | None = None,
                    object_: str | None = None) -> "ChatResponseChunk":
        """Wrap ``data`` in a chunk whose single choice carries a full ``message`` (not a ``delta``).

        Omitted arguments fall back to a random UUID4 id, the current UTC time, an empty model name
        and the ``"chat.completion.chunk"`` object tag.
        """
        only_choice = Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")
        return ChatResponseChunk(
            id=str(uuid.uuid4()) if id_ is None else id_,
            choices=[only_choice],
            created=datetime.datetime.now(datetime.timezone.utc) if created is None else created,
            model="" if model is None else model,
            object="chat.completion.chunk" if object_ is None else object_)

    @staticmethod
    def create_streaming_chunk(content: str,
                               *,
                               id_: str | None = None,
                               created: datetime.datetime | None = None,
                               model: str | None = None,
                               role: str | None = None,
                               finish_reason: str | None = None,
                               usage: Usage | None = None,
                               system_fingerprint: str | None = None) -> "ChatResponseChunk":
        """Create an OpenAI-compatible streaming chunk whose single choice carries a ``delta``.

        ``finish_reason`` is kept only when it is one of ``FINISH_REASONS``; any other value becomes None.
        """
        if content is None and role is None:
            delta = ChoiceDelta()
        else:
            delta = ChoiceDelta(content=content, role=role)

        accepted_finish_reason = finish_reason if finish_reason in FINISH_REASONS else None
        streamed_choice = Choice(index=0, message=None, delta=delta, finish_reason=accepted_finish_reason)

        return ChatResponseChunk(
            id=str(uuid.uuid4()) if id_ is None else id_,
            choices=[streamed_choice],
            created=datetime.datetime.now(datetime.timezone.utc) if created is None else created,
            model="" if model is None else model,
            object="chat.completion.chunk",
            usage=usage,
            system_fingerprint=system_fingerprint)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class ResponseIntermediateStep(ResponseBaseModelIntermediate):
    """
    ResponseIntermediateStep is a data model that represents a serialized intermediate step in the NAT chat
    streaming API. It is streamed as an ``intermediate_data:`` line.
    """

    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    id: str
    # Id of the enclosing step, if any.
    parent_id: str | None = None
    # Rendering hint for the payload; defaults to "markdown".
    type: str = "markdown"
    name: str
    payload: str
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
class ResponsePayloadOutput(BaseModel, ResponseSerializable):
    """Wrapper that streams an arbitrary payload as a ``data:`` line."""

    payload: typing.Any

    def get_stream_data(self) -> str:
        """Return ``data: <payload>`` followed by a blank line.

        Pydantic payloads are rendered with ``model_dump_json()``; anything else is formatted as-is.
        """
        rendered = self.payload.model_dump_json() if isinstance(self.payload, BaseModel) else self.payload
        return f"data: {rendered}\n\n"
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class GenerateResponse(BaseModel):
    """Result of a non-chat ``generate`` workflow call."""

    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    # FIXME: define a dedicated model for intermediate steps instead of bare tuples.
    intermediate_steps: typing.Optional[list[tuple]] = None
    output: str
    value: typing.Optional[str] = "default"
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
class UserMessageContentRoleType(str, Enum):
    """Roles allowed for a message in ``UserMessages.role``."""
    USER = "user"
    ASSISTANT = "assistant"
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
class WebSocketMessageType(str, Enum):
    """
    WebSocketMessageType is an Enum that represents WebSocket Message types.
    """
    # ``type`` of WebSocketUserMessage.
    USER_MESSAGE = "user_message"
    # ``type`` of WebSocketSystemResponseTokenMessage for regular responses.
    RESPONSE_MESSAGE = "system_response_message"
    # ``type`` of WebSocketSystemIntermediateStepMessage.
    INTERMEDIATE_STEP_MESSAGE = "system_intermediate_message"
    # ``type`` of WebSocketSystemInteractionMessage.
    SYSTEM_INTERACTION_MESSAGE = "system_interaction_message"
    # ``type`` of WebSocketUserInteractionResponseMessage.
    USER_INTERACTION_MESSAGE = "user_interaction_message"
    # ``type`` of WebSocketSystemResponseTokenMessage when its content is an Error.
    ERROR_MESSAGE = "error_message"
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
class WorkflowSchemaType(str, Enum):
    """
    WorkflowSchemaType is an Enum that represents Workflow response types.

    It is the ``schema_type`` a WebSocketUserMessage requests: streaming or non-streaming,
    generate-style or chat-style.
    """
    GENERATE_STREAM = "generate_stream"
    CHAT_STREAM = "chat_stream"
    GENERATE = "generate"
    CHAT = "chat"
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class WebSocketMessageStatus(str, Enum):
    """
    WebSocketMessageStatus is an Enum that represents the status of a WebSocket message.

    Used as the ``status`` field of the system-sent WebSocket message models.
    """
    IN_PROGRESS = "in_progress"
    COMPLETE = "complete"
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class UserMessages(BaseModel):
    """One role-tagged message inside a UserMessageContent; unknown keys are rejected."""
    model_config = ConfigDict(extra="forbid")

    # Either "user" or "assistant" (see UserMessageContentRoleType).
    role: UserMessageContentRoleType
    # Content parts of the message.
    content: list[UserContent]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class UserMessageContent(BaseModel):
    """Message list carried in the ``content`` of user-sent WebSocket messages; unknown keys are rejected."""
    model_config = ConfigDict(extra="forbid")
    messages: list[UserMessages]
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
class User(BaseModel):
    """User identity attached to user-sent WebSocket messages; unknown keys are rejected."""
    model_config = ConfigDict(extra="forbid")

    # "default" is a placeholder used when the client supplies no value.
    name: str = "default"
    email: str = "default"
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class ErrorTypes(str, Enum):
    """Error codes usable in ``Error.code``; ``UNKNOWN_ERROR`` is the default there."""
    UNKNOWN_ERROR = "unknown_error"
    INVALID_MESSAGE = "invalid_message"
    INVALID_MESSAGE_TYPE = "invalid_message_type"
    INVALID_USER_MESSAGE_CONTENT = "invalid_user_message_content"
    INVALID_DATA_CONTENT = "invalid_data_content"
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
class Error(BaseModel):
    """Error payload for WebSocket messages; unknown keys are rejected."""
    model_config = ConfigDict(extra="forbid")

    code: ErrorTypes = ErrorTypes.UNKNOWN_ERROR
    # "default" is a placeholder used when no message/details are supplied.
    message: str = "default"
    details: str = "default"
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
class WebSocketUserMessage(BaseModel):
    """
    User-sent WebSocket message of type ``user_message`` requesting a workflow run of ``schema_type``.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[WebSocketMessageType.USER_MESSAGE]
    schema_type: WorkflowSchemaType
    id: str = "default"
    conversation_id: str | None = None
    content: UserMessageContent
    user: User = User()
    security: Security = Security()
    error: Error = Error()
    schema_version: str = "1.0.0"
    # Computed per instance. A plain ``str(datetime.now(...))`` default is evaluated once at import
    # time and would stamp every message with the same time.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
class WebSocketUserInteractionResponseMessage(BaseModel):
    """
    User-sent WebSocket message of type ``user_interaction_message``.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    type: typing.Literal[WebSocketMessageType.USER_INTERACTION_MESSAGE]
    id: str = "default"
    thread_id: str = "default"
    content: UserMessageContent
    user: User = User()
    security: Security = Security()
    error: Error = Error()
    schema_version: str = "1.0.0"
    # Computed per instance. A plain ``str(datetime.now(...))`` default is evaluated once at import
    # time and would stamp every message with the same time.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
class SystemIntermediateStepContent(BaseModel):
    """Name and payload of an intermediate step sent over WebSocket; unknown keys are rejected."""
    model_config = ConfigDict(extra="forbid")
    name: str
    payload: str
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
class WebSocketSystemIntermediateStepMessage(BaseModel):
    """
    System-sent WebSocket message of type ``system_intermediate_message`` carrying one intermediate step.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE]
    id: str = "default"
    thread_id: str | None = "default"
    parent_id: str = "default"
    intermediate_parent_id: str | None = "default"
    update_message_id: str | None = "default"
    conversation_id: str | None = None
    content: SystemIntermediateStepContent
    status: WebSocketMessageStatus
    # Computed per instance. A plain ``str(datetime.now(...))`` default is evaluated once at import
    # time and would stamp every message with the same time.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
class SystemResponseContent(BaseModel):
    """Plain-text response content sent over WebSocket; unknown keys are rejected."""
    model_config = ConfigDict(extra="forbid")

    text: str | None = None
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
class WebSocketSystemResponseTokenMessage(BaseModel):
    """
    System-sent WebSocket response message, either a regular response or an error.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[WebSocketMessageType.RESPONSE_MESSAGE, WebSocketMessageType.ERROR_MESSAGE]
    id: str | None = "default"
    thread_id: str | None = "default"
    parent_id: str = "default"
    conversation_id: str | None = None
    content: SystemResponseContent | Error | GenerateResponse
    status: WebSocketMessageStatus
    # Computed per instance. A plain ``str(datetime.now(...))`` default is evaluated once at import
    # time and would stamp every message with the same time.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))

    @field_validator("content")
    @classmethod
    def validate_content_by_type(cls, value: SystemResponseContent | Error | GenerateResponse, info: ValidationInfo):
        """Check that ``content`` matches the already-validated ``type``.

        Error messages require ``Error`` content. Response messages require ``SystemResponseContent``
        or ``GenerateResponse`` content.

        Raises:
            ValueError: If the content type does not match the message type.
        """
        # ``type`` is declared before ``content``, so it is available in info.data once it has validated.
        if info.data.get("type") == WebSocketMessageType.ERROR_MESSAGE and not isinstance(value, Error):
            raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")

        if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
                value, (SystemResponseContent, GenerateResponse)):
            # Name both accepted types; the previous message omitted GenerateResponse.
            raise ValueError("Field: content must be 'SystemResponseContent' or 'GenerateResponse' when type is "
                             f"{WebSocketMessageType.RESPONSE_MESSAGE}")
        return value
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
class WebSocketSystemInteractionMessage(BaseModel):
    """
    System-sent WebSocket message of type ``system_interaction_message`` carrying a HumanPrompt.

    For more details, refer to the API documentation:
    docs/source/developer_guide/websockets.md
    """
    # Extra fields are allowed so that derived models can add their own.
    model_config = ConfigDict(extra="allow")

    type: typing.Literal[
        WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE
    id: str | None = "default"
    thread_id: str | None = "default"
    parent_id: str = "default"
    conversation_id: str | None = None
    content: HumanPrompt
    status: WebSocketMessageStatus
    # Computed per instance. A plain ``str(datetime.now(...))`` default is evaluated once at import
    # time and would stamp every message with the same time.
    timestamp: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc)))
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
# ======== GenerateResponse Converters ========
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def _generate_response_to_str(response: GenerateResponse) -> str:
    """Return the ``output`` text of a GenerateResponse."""
    output_text = response.output
    return output_text


GlobalTypeConverter.register_converter(_generate_response_to_str)
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def _generate_response_to_chat_response(response: GenerateResponse) -> ChatResponse:
    """Convert a GenerateResponse into a single-choice ChatResponse.

    Usage is simulated: ``completion_tokens`` is the whitespace-separated word count of the output
    and ``prompt_tokens`` is always 0.
    """
    text = response.output
    word_count = len(text.split())
    simulated_usage = Usage(prompt_tokens=0, completion_tokens=word_count, total_tokens=word_count)
    return ChatResponse.from_string(text, usage=simulated_usage)


GlobalTypeConverter.register_converter(_generate_response_to_chat_response)
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
# ======== ChatRequest Converters ========
|
|
615
|
+
def _nat_chat_request_to_string(data: ChatRequest) -> str:
    """Return the content of the last message in the request.

    String content is returned unchanged; any other content is converted with ``str()``.
    """
    last_content = data.messages[-1].content
    return last_content if isinstance(last_content, str) else str(last_content)


GlobalTypeConverter.register_converter(_nat_chat_request_to_string)
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def _string_to_nat_chat_request(data: str) -> ChatRequest:
    """Build a ChatRequest from a plain string, using an empty model name."""
    request = ChatRequest.from_string(data, model="")
    return request


GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
# ======== ChatResponse Converters ========
|
|
632
|
+
def _nat_chat_response_to_string(data: ChatResponse) -> str:
    """Return the first choice's message content, or "" when no such content exists."""
    if not data.choices:
        return ""
    first_message = data.choices[0].message
    if not first_message:
        return ""
    return first_message.content or ""


GlobalTypeConverter.register_converter(_nat_chat_response_to_string)
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def _string_to_nat_chat_response(data: str) -> ChatResponse:
    '''Converts a string to a ChatResponse object.

    Usage is simulated: ``completion_tokens`` is the whitespace-separated word count of ``data``
    and ``prompt_tokens`` is always 0.
    '''
    word_count = len(data.split())
    simulated_usage = Usage(prompt_tokens=0, completion_tokens=word_count, total_tokens=word_count)
    return ChatResponse.from_string(data, usage=simulated_usage)


GlobalTypeConverter.register_converter(_string_to_nat_chat_response)
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk:
    """Re-wrap a ChatResponse as a ChatResponseChunk.

    Only ``id``, ``choices``, ``created`` and ``model`` are copied. The choices keep their original
    ``message`` structure for backward compatibility.
    """
    return ChatResponseChunk(
        id=data.id,
        created=data.created,
        model=data.model,
        choices=data.choices,
    )


GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
# ======== ChatResponseChunk Converters ========
|
|
666
|
+
def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str:
    """Return the first choice's text, preferring ``delta`` over ``message``.

    Returns "" when there is no choice or neither part has non-empty content.
    """
    if not data.choices:
        return ""
    first_choice = data.choices[0]
    for part in (first_choice.delta, first_choice.message):
        if part and part.content:
            return part.content
    return ""


GlobalTypeConverter.register_converter(_chat_response_chunk_to_string)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk:
    '''Converts a string to a ChatResponseChunk object via ChatResponseChunk.from_string.'''
    chunk = ChatResponseChunk.from_string(data)
    return chunk


GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
# ======== AINodeMessageChunk Converters ========
|
|
690
|
+
def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
    '''Converts LangChain AINodeMessageChunk to ChatResponseChunk.

    The text is taken from the first of ``content``, ``text`` or ``message`` that exists and is not None,
    converted with ``str()``. It falls back to "" when none qualifies. The result is an assistant-role
    streaming chunk with no finish reason.
    '''
    text = ""
    for attribute_name in ("content", "text", "message"):
        candidate = getattr(data, attribute_name, None)
        if candidate is not None:
            text = str(candidate)
            break

    return ChatResponseChunk.create_streaming_chunk(content=text, role="assistant", finish_reason=None)
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
# Compatibility aliases with previous releases: the legacy ``AIQ``-prefixed names are bound to the
# same class objects as the current names, so code importing either spelling sees identical types.
AIQChatRequest = ChatRequest
AIQChoiceMessage = ChoiceMessage
AIQChoiceDelta = ChoiceDelta
AIQChoice = Choice
AIQUsage = Usage
AIQResponseSerializable = ResponseSerializable
AIQResponseBaseModelOutput = ResponseBaseModelOutput
AIQResponseBaseModelIntermediate = ResponseBaseModelIntermediate
AIQChatResponse = ChatResponse
AIQChatResponseChunk = ChatResponseChunk
AIQResponseIntermediateStep = ResponseIntermediateStep
AIQResponsePayloadOutput = ResponsePayloadOutput
AIQGenerateResponse = GenerateResponse
|