nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,627 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import dataclasses
|
|
17
|
+
import inspect
|
|
18
|
+
import logging
|
|
19
|
+
import typing
|
|
20
|
+
from collections.abc import AsyncGenerator
|
|
21
|
+
from collections.abc import Awaitable
|
|
22
|
+
from collections.abc import Callable
|
|
23
|
+
from collections.abc import Coroutine
|
|
24
|
+
from types import NoneType
|
|
25
|
+
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
from pydantic import ConfigDict
|
|
28
|
+
from pydantic import Field
|
|
29
|
+
from pydantic import create_model
|
|
30
|
+
from pydantic_core import PydanticUndefined
|
|
31
|
+
|
|
32
|
+
from nat.data_models.streaming import Streaming
|
|
33
|
+
from nat.utils.type_utils import DecomposedType
|
|
34
|
+
|
|
35
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
# Parameter specification used to parameterize the two callable aliases below.
P = typing.ParamSpec("P")
|
|
38
|
+
# A callable taking parameters `P` whose call produces a coroutine resolving to any value.
SingleCallableT = Callable[P, Coroutine[None, None, typing.Any]]
|
|
39
|
+
# A callable taking parameters `P` whose call produces an async generator of any value.
StreamCallableT = Callable[P, AsyncGenerator[typing.Any]]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_annotated_type(annotated_type: type) -> type:
|
|
43
|
+
origin = typing.get_origin(annotated_type)
|
|
44
|
+
args = typing.get_args(annotated_type)
|
|
45
|
+
|
|
46
|
+
# If its annotated, the first arg is the type
|
|
47
|
+
if (origin == typing.Annotated):
|
|
48
|
+
return args[0]
|
|
49
|
+
|
|
50
|
+
return annotated_type
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _validate_single_fn(single_fn: SingleCallableT | None) -> tuple[type, type]:
    """
    Validate a single-output function and extract its input and output types.

    Parameters
    ----------
    single_fn : SingleCallableT | None
        The function to validate, or None when no single-output function was supplied.

    Returns
    -------
    tuple[type, type]
        `(input_type, output_type)` resolved from the function's type hints, or `(NoneType, NoneType)` when
        `single_fn` is None.

    Raises
    ------
    ValueError
        If the function does not take exactly one parameter, is missing the input or return annotation, or is not a
        coroutine function.
    """
    if single_fn is None:
        return NoneType, NoneType

    signature = inspect.signature(single_fn)
    parameters = list(signature.parameters.values())

    if len(parameters) != 1:
        raise ValueError("single_fn must have exactly one parameter")

    only_param = parameters[0]

    if (only_param.annotation == signature.empty):
        raise ValueError("single_fn must have an input annotation")

    if signature.return_annotation == signature.empty:
        raise ValueError("single_fn must have a return annotation")

    if not inspect.iscoroutinefunction(single_fn):
        raise ValueError("single_fn must be a coroutine")

    # Resolve the annotations into real types (the raw signature may hold unevaluated strings)
    resolved_hints = typing.get_type_hints(single_fn)
    output_type = resolved_hints.pop("return")

    # With the return hint removed, only the single input parameter should remain
    assert len(resolved_hints) == 1

    input_type = next(iter(resolved_hints.values()))

    return input_type, output_type
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _validate_stream_fn(stream_fn: StreamCallableT | None) -> tuple[type, type]:
    """
    Validate a streaming function and extract its input type and streamed output type.

    Parameters
    ----------
    stream_fn : StreamCallableT | None
        The function to validate, or None when no streaming function was supplied.

    Returns
    -------
    tuple[type, type]
        `(input_type, output_type)` where `output_type` is taken from the async generator return annotation, or
        `(NoneType, NoneType)` when `stream_fn` is None.

    Raises
    ------
    ValueError
        If the function does not take exactly one parameter, is missing the input or return annotation, is not an
        async generator function, or its return value is not annotated as an async generator.
    """

    if stream_fn is None:
        return NoneType, NoneType

    sig = inspect.signature(stream_fn)

    if len(sig.parameters) != 1:
        raise ValueError("stream_fn must have exactly one parameter")

    # Mirror the check in `_validate_single_fn`. Without it an un-annotated parameter would only be caught by the
    # assertion further down, surfacing as an opaque AssertionError instead of a descriptive ValueError.
    if (next(iter(sig.parameters.values())).annotation == sig.empty):
        raise ValueError("stream_fn must have an input annotation")

    if sig.return_annotation == sig.empty:
        raise ValueError("stream_fn must have a return annotation")

    if not inspect.isasyncgenfunction(stream_fn):
        raise ValueError("stream_fn must be an async generator")

    type_hints = typing.get_type_hints(stream_fn)

    # AsyncGenerator[OutputType, None]
    async_gen_type = DecomposedType(type_hints.pop("return"))

    if (not async_gen_type.is_async_generator):
        raise ValueError("stream_fn return value must be annotated as an async generator")

    # If the output type is annotated, get the actual type
    output_type = async_gen_type.get_async_generator_type().type

    # With the return hint removed, only the single input parameter should remain
    assert len(type_hints) == 1

    input_type = next(iter(type_hints.values()))

    return input_type, output_type
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclasses.dataclass
|
|
118
|
+
class FunctionDescriptor:
|
|
119
|
+
|
|
120
|
+
func: Callable
|
|
121
|
+
|
|
122
|
+
arg_count: int
|
|
123
|
+
|
|
124
|
+
is_coroutine: bool
|
|
125
|
+
"""
|
|
126
|
+
Whether the function is a coroutine or not.
|
|
127
|
+
"""
|
|
128
|
+
|
|
129
|
+
is_async_gen: bool
|
|
130
|
+
"""
|
|
131
|
+
Whether the function is an async generator or not.
|
|
132
|
+
"""
|
|
133
|
+
|
|
134
|
+
input_type: type | type[None] | None
|
|
135
|
+
"""
|
|
136
|
+
The direct annotated input type to the function. If the function has multiple arguments, this will be a tuple of
|
|
137
|
+
the annotated types. If the function has no annotations, this will be None. If the function has no arguments, this
|
|
138
|
+
will be NoneType.
|
|
139
|
+
"""
|
|
140
|
+
input_schema: type[BaseModel] | type[None] | None
|
|
141
|
+
"""
|
|
142
|
+
The Pydantic schema for the input to the function. This will always be a Pydantic model with the arguments as fields
|
|
143
|
+
( even if the function only has one BaseModel input argument). If the function has no input, this will be NoneType.
|
|
144
|
+
If the function has no annotations, this will be None.
|
|
145
|
+
"""
|
|
146
|
+
input_type_is_base_model: bool
|
|
147
|
+
"""
|
|
148
|
+
True if the input type is a subclass of BaseModel, False otherwise
|
|
149
|
+
"""
|
|
150
|
+
|
|
151
|
+
output_type: type | type[None] | None
|
|
152
|
+
"""
|
|
153
|
+
The direct annotated output type to the function. If the function has no annotations, this will be None. If the
|
|
154
|
+
function has no return type, this will be NoneType.
|
|
155
|
+
"""
|
|
156
|
+
output_schema: type[BaseModel] | type[None] | None
|
|
157
|
+
"""
|
|
158
|
+
The Pydantic schema for the output of the function. If the return type is already a BaseModel, the schema will be
|
|
159
|
+
the same as the `output_type`. If the function has no return type, this will be NoneType. If the function has no
|
|
160
|
+
annotations, this will be None.
|
|
161
|
+
"""
|
|
162
|
+
output_type_is_base_model: bool
|
|
163
|
+
"""
|
|
164
|
+
True if the output type is a subclass of BaseModel, False otherwise
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
is_input_typed: bool
|
|
168
|
+
"""
|
|
169
|
+
True if all of the functions input arguments have type annotations, False otherwise
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
is_output_typed: bool
|
|
173
|
+
"""
|
|
174
|
+
True if the function has a return type annotation, False otherwise
|
|
175
|
+
"""
|
|
176
|
+
|
|
177
|
+
converters: list[Callable]
|
|
178
|
+
"""
|
|
179
|
+
A list of converters for converting to/from the function's input/output types. Converters are created when
|
|
180
|
+
determining the output schema of a function.
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
def get_base_model_function_input(self) -> type[BaseModel] | type[None] | None:
|
|
184
|
+
"""
|
|
185
|
+
Returns a BaseModel type which can be used as the function input. If the InputType is a BaseModel, it will be
|
|
186
|
+
returned, otherwise the InputSchema will be returned. If the function has no input, NoneType will be returned.
|
|
187
|
+
"""
|
|
188
|
+
|
|
189
|
+
if self.input_type_is_base_model:
|
|
190
|
+
return self.input_type
|
|
191
|
+
|
|
192
|
+
return self.input_schema
|
|
193
|
+
|
|
194
|
+
def get_base_model_function_output(self,
|
|
195
|
+
converters: list[Callable] | None = None) -> type[BaseModel] | type[None] | None:
|
|
196
|
+
"""
|
|
197
|
+
Returns a BaseModel type which can be used as the function output. If the OutputType is a BaseModel, it will be
|
|
198
|
+
returned, otherwise the OutputSchema will be returned. If the function has no output, NoneType will be returned.
|
|
199
|
+
"""
|
|
200
|
+
|
|
201
|
+
if (converters is not None):
|
|
202
|
+
converters.extend(self.converters)
|
|
203
|
+
|
|
204
|
+
if self.output_type_is_base_model:
|
|
205
|
+
return self.output_type
|
|
206
|
+
|
|
207
|
+
return self.output_schema
|
|
208
|
+
|
|
209
|
+
@staticmethod
|
|
210
|
+
def from_function(func: Callable) -> 'FunctionDescriptor':
|
|
211
|
+
|
|
212
|
+
is_coroutine = inspect.iscoroutinefunction(func)
|
|
213
|
+
is_async_gen = inspect.isasyncgenfunction(func)
|
|
214
|
+
|
|
215
|
+
converters = []
|
|
216
|
+
|
|
217
|
+
sig = inspect.signature(func)
|
|
218
|
+
|
|
219
|
+
arg_count = len(sig.parameters)
|
|
220
|
+
|
|
221
|
+
if (arg_count == 0):
|
|
222
|
+
input_type = NoneType
|
|
223
|
+
is_input_typed = False
|
|
224
|
+
input_schema = NoneType
|
|
225
|
+
elif (arg_count == 1):
|
|
226
|
+
first_annotation = sig.parameters[list(sig.parameters.keys())[0]].annotation
|
|
227
|
+
|
|
228
|
+
is_input_typed = first_annotation != sig.empty
|
|
229
|
+
|
|
230
|
+
input_type = first_annotation if is_input_typed else None
|
|
231
|
+
else:
|
|
232
|
+
annotations = [param.annotation for param in sig.parameters.values()]
|
|
233
|
+
|
|
234
|
+
is_input_typed = all([a != sig.empty for a in annotations]) # pylint: disable=use-a-generator
|
|
235
|
+
|
|
236
|
+
input_type = tuple[*annotations] if is_input_typed else None # noqa: syntax-error
|
|
237
|
+
|
|
238
|
+
# Get the base type here removing all annotations and async generators
|
|
239
|
+
output_annotation_decomp = DecomposedType(sig.return_annotation).get_base_type()
|
|
240
|
+
|
|
241
|
+
is_output_typed = not output_annotation_decomp.is_empty
|
|
242
|
+
|
|
243
|
+
output_type = output_annotation_decomp.type if is_output_typed else None
|
|
244
|
+
|
|
245
|
+
output_schema = output_annotation_decomp.get_pydantic_schema(converters) if is_output_typed else None
|
|
246
|
+
|
|
247
|
+
if (input_type is not None):
|
|
248
|
+
|
|
249
|
+
args_schema: dict[str, tuple[type, typing.Any]] = {}
|
|
250
|
+
|
|
251
|
+
for param in sig.parameters.values():
|
|
252
|
+
|
|
253
|
+
default_val = PydanticUndefined
|
|
254
|
+
|
|
255
|
+
if (param.default != sig.empty):
|
|
256
|
+
default_val = param.default
|
|
257
|
+
|
|
258
|
+
args_schema[param.name] = (param.annotation, Field(default=default_val))
|
|
259
|
+
|
|
260
|
+
input_schema = create_model("InputArgsSchema",
|
|
261
|
+
__config__=ConfigDict(arbitrary_types_allowed=True),
|
|
262
|
+
**args_schema)
|
|
263
|
+
else:
|
|
264
|
+
input_schema = None
|
|
265
|
+
|
|
266
|
+
input_type_is_base_model = False
|
|
267
|
+
output_type_is_base_model = False
|
|
268
|
+
|
|
269
|
+
if (input_type is not None):
|
|
270
|
+
input_type_is_base_model = DecomposedType(input_type).is_subtype(BaseModel)
|
|
271
|
+
|
|
272
|
+
if (output_type is not None):
|
|
273
|
+
output_type_is_base_model = DecomposedType(output_type).is_subtype(BaseModel)
|
|
274
|
+
|
|
275
|
+
return FunctionDescriptor(func=func,
|
|
276
|
+
arg_count=arg_count,
|
|
277
|
+
is_coroutine=is_coroutine,
|
|
278
|
+
is_async_gen=is_async_gen,
|
|
279
|
+
is_input_typed=is_input_typed,
|
|
280
|
+
is_output_typed=is_output_typed,
|
|
281
|
+
input_type=input_type,
|
|
282
|
+
output_type=output_type,
|
|
283
|
+
input_schema=input_schema,
|
|
284
|
+
output_schema=output_schema,
|
|
285
|
+
input_type_is_base_model=input_type_is_base_model,
|
|
286
|
+
output_type_is_base_model=output_type_is_base_model,
|
|
287
|
+
converters=converters)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class FunctionInfo:
|
|
291
|
+
|
|
292
|
+
def __init__(self,
             *,
             single_fn: SingleCallableT | None = None,
             stream_fn: StreamCallableT | None = None,
             input_schema: type[BaseModel] | type[None],
             single_output_schema: type[BaseModel] | type[None],
             stream_output_schema: type[BaseModel] | type[None],
             description: str | None = None,
             converters: list[Callable] | None = None):
    """
    Store and validate a fully specified function description.

    No conversion or inference happens here; the arguments are only stored and cross-checked. Throughout this
    method `NoneType` means "there is no such type/schema" while `None` means "the value was not set".

    Parameters
    ----------
    single_fn : SingleCallableT | None, optional
        The single-output callable, by default None
    stream_fn : StreamCallableT | None, optional
        The streaming callable, by default None
    input_schema : type[BaseModel] | type[None]
        The schema describing the input. Must not be None.
    single_output_schema : type[BaseModel] | type[None]
        The schema of the single output. Must be NoneType exactly when `single_fn` is not provided.
    stream_output_schema : type[BaseModel] | type[None]
        The schema of the streamed output. Must be NoneType exactly when `stream_fn` is not provided.
    description : str | None, optional
        A description of the function, by default None
    converters : list[Callable] | None, optional
        Converters stored alongside the function, by default None

    Raises
    ------
    ValueError
        If neither function is provided, if the two functions disagree on their input type, or if any schema is
        missing or inconsistent with the functions that were provided.
    """
    self.single_fn = single_fn
    self.stream_fn = stream_fn
    self.input_schema = input_schema
    self.single_output_schema = single_output_schema
    self.stream_output_schema = stream_output_schema
    self.description = description
    self.converters = converters

    # At this point, we only are validating the passed in information. We are not converting anything. That will
    # be done in the `create()` and `from_fn()` static methods.
    # NOTE(review): the validators are defined elsewhere in this file; the checks below assume they report a
    # missing function as NoneType for both the input and the output type -- confirm.
    single_input_type, single_output_type = _validate_single_fn(single_fn)
    stream_input_type, stream_output_type = _validate_stream_fn(stream_fn)

    # When both functions are present they must accept the same input type
    if ((NoneType not in (single_input_type, stream_input_type)) and (single_input_type != stream_input_type)):
        raise ValueError("single_fn and stream_fn must have the same input type")

    # Prefer the single function's input type and fall back to the stream function's. The sentinel for
    # "no function" is NoneType (not None), so both branches must test against NoneType.
    if (single_input_type is not NoneType):
        self.input_type = single_input_type
    elif (stream_input_type is not NoneType):
        self.input_type = stream_input_type
    else:
        raise ValueError("At least one of single_fn or stream_fn must be provided")

    self.single_output_type: type = single_output_type
    self.stream_output_type: type = stream_output_type

    # Defensive check that does not rely on what the validators return for a missing function
    if (self.single_fn is None and self.stream_fn is None):
        raise ValueError("At least one of single_fn or stream_fn must be provided")

    # All of the schemas must be provided. NoneType indicates there is no type. None indicates not set
    if (self.input_schema is None):
        raise ValueError("input_schema must be provided")

    if (self.single_output_schema is None):
        raise ValueError("single_output_schema must be provided. Use NoneType if there is no single output")

    if (self.stream_output_schema is None):
        raise ValueError("stream_output_schema must be provided. Use NoneType if there is no stream output")

    # A schema must be a real type exactly when the matching function exists
    if (self.single_fn and self.single_output_schema == NoneType):
        raise ValueError("single_output_schema must be provided if single_fn is provided")
    if (not self.single_fn and self.single_output_schema != NoneType):
        raise ValueError("single_output_schema must be NoneType if single_fn is not provided")

    if (self.stream_fn and self.stream_output_schema is NoneType):
        raise ValueError("stream_output_schema must be provided if stream_fn is provided")
    if (not self.stream_fn and self.stream_output_schema != NoneType):
        raise ValueError("stream_output_schema must be NoneType if stream_fn is not provided")
|
|
349
|
+
|
|
350
|
+
@staticmethod
def create(*,
           single_fn: SingleCallableT | None = None,
           stream_fn: StreamCallableT | None = None,
           input_schema: type[BaseModel] | type[None] | None = None,
           single_output_schema: type[BaseModel] | type[None] | None = None,
           stream_output_schema: type[BaseModel] | type[None] | None = None,
           single_to_stream_fn: Callable[[typing.Any], AsyncGenerator[typing.Any]] | None = None,
           stream_to_single_fn: Callable[[AsyncGenerator[typing.Any]], Awaitable[typing.Any]] | None = None,
           description: str | None = None,
           converters: list[Callable] | None = None) -> 'FunctionInfo':
    """
    Creates a FunctionInfo object from single and/or stream functions, filling in whatever was not supplied.

    Missing pieces are derived as follows:

    - A missing stream function is built from `single_fn` and `single_to_stream_fn`, or, when no conversion
      function is given, by yielding the single result once.
    - A missing single function is built from `stream_fn` and `stream_to_single_fn`.
    - Functions taking more than one argument are wrapped so they accept one pydantic model whose dumped fields
      are passed as keyword arguments.
    - Any schema left as None is taken from the descriptor of the final function.

    Parameters
    ----------
    single_fn : SingleCallableT | None, optional
        The single-output coroutine, by default None
    stream_fn : StreamCallableT | None, optional
        The streaming async generator, by default None
    input_schema : type[BaseModel] | type[None] | None, optional
        The input schema. Derived from the final functions when None, by default None
    single_output_schema : type[BaseModel] | type[None] | None, optional
        The single output schema. Derived from the final single function when None, by default None
    stream_output_schema : type[BaseModel] | type[None] | None, optional
        The stream output schema. Derived from the final stream function when None, by default None
    single_to_stream_fn : Callable | None, optional
        An async generator turning the output of `single_fn` into a stream. Requires `single_fn` and cannot be
        combined with `stream_fn`, by default None
    stream_to_single_fn : Callable | None, optional
        A coroutine collapsing the generator produced by `stream_fn` into one value. Requires `stream_fn` and
        cannot be combined with `single_fn`, by default None
    description : str | None, optional
        A description to set to the function, by default None
    converters : list[Callable] | None, optional
        A list of converters used when deriving output schemas, by default None

    Returns
    -------
    FunctionInfo
        The created FunctionInfo object.

    Raises
    ------
    ValueError
        If the combination of functions is invalid, or a conversion function does not have exactly one argument,
        a return annotation, the required coroutine/async generator form, or a matching input type.
    """

    converters = converters or []

    final_single_fn: SingleCallableT | None = None
    final_stream_fn: StreamCallableT | None = None

    # Check the correct combination of functions
    if (single_fn is not None):
        final_single_fn = single_fn

        if (stream_to_single_fn is not None):
            raise ValueError("Cannot provide both single_fn and stream_to_single_fn")
    else:
        if (stream_to_single_fn is not None and stream_fn is None):
            raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")

    if (stream_fn is not None):
        final_stream_fn = stream_fn

        if (single_to_stream_fn is not None):
            raise ValueError("Cannot provide both stream_fn and single_to_stream_fn")
    else:
        if (single_to_stream_fn is not None and single_fn is None):
            raise ValueError("single_fn must be provided if single_to_stream_fn is provided")

    if (single_fn is None and stream_fn is None):
        raise ValueError("At least one of single_fn or stream_fn must be provided")

    # Now we know that we have the correct combination of functions. See if we can make conversions
    if (single_to_stream_fn is not None):

        # Already guaranteed by the checks above; kept as an explicit guard
        if (single_fn is None):
            raise ValueError("single_fn must be provided if single_to_stream_fn is provided")

        single_to_stream_fn_desc = FunctionDescriptor.from_function(single_to_stream_fn)

        if single_to_stream_fn_desc.arg_count != 1:
            raise ValueError("single_to_stream_fn must have exactly one argument")

        if not single_to_stream_fn_desc.is_output_typed:
            raise ValueError("single_to_stream_fn must have a return annotation")

        if not single_to_stream_fn_desc.is_async_gen:
            raise ValueError("single_to_stream_fn must be an async generator")

        single_fn_desc = FunctionDescriptor.from_function(single_fn)

        if (single_fn_desc.output_type != single_to_stream_fn_desc.input_type):
            raise ValueError("single_to_stream_fn must have the same input type as the output from single_fn")

        # The annotations are real type objects because the wrapper is inspected again further down
        async def _converted_stream_fn(
                message: single_fn_desc.input_type) -> AsyncGenerator[single_to_stream_fn_desc.output_type]:
            value = await single_fn(message)

            async for m in single_to_stream_fn(value):
                yield m

        final_stream_fn = _converted_stream_fn

    if (stream_to_single_fn is not None):

        # Already guaranteed by the checks above; kept as an explicit guard
        if (stream_fn is None):
            raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")

        stream_to_single_fn_desc = FunctionDescriptor.from_function(stream_to_single_fn)

        if stream_to_single_fn_desc.arg_count != 1:
            raise ValueError("stream_to_single_fn must have exactly one parameter")

        if not stream_to_single_fn_desc.is_output_typed:
            raise ValueError("stream_to_single_fn must have a return annotation")

        if not stream_to_single_fn_desc.is_coroutine:
            raise ValueError("stream_to_single_fn must be a coroutine")

        stream_fn_desc = FunctionDescriptor.from_function(stream_fn)

        if (AsyncGenerator[stream_fn_desc.output_type] != stream_to_single_fn_desc.input_type):
            raise ValueError("stream_to_single_fn must take an async generator with "
                             "the same input type as the output from stream_fn")

        async def _converted_single_fn(message: stream_fn_desc.input_type) -> stream_to_single_fn_desc.output_type:

            return await stream_to_single_fn(stream_fn(message))

        final_single_fn = _converted_single_fn

    # Check the input/output of the functions to make sure they are all BaseModels
    if (final_single_fn is not None):

        final_single_fn_desc = FunctionDescriptor.from_function(final_single_fn)

        if (final_single_fn_desc.arg_count > 1):
            if (input_schema is not None):
                logger.warning("Using provided input_schema for multi-argument function")
            else:
                input_schema = final_single_fn_desc.get_base_model_function_input()

            # Capture the current function under a separate name: `final_single_fn` is rebound to the wrapper
            # below, so referring to it inside the wrapper would make the wrapper call itself.
            saved_final_single_fn = final_single_fn

            async def _convert_input_pydantic(value: input_schema) -> final_single_fn_desc.output_type:

                # Unpack the pydantic model into the arguments
                return await saved_final_single_fn(**value.model_dump())

            final_single_fn = _convert_input_pydantic

            # Reset the descriptor
            final_single_fn_desc = FunctionDescriptor.from_function(final_single_fn)

        input_schema = input_schema or final_single_fn_desc.get_base_model_function_input()

        single_output_schema = single_output_schema or final_single_fn_desc.get_base_model_function_output(
            converters)

        # Check if the final_stream_fn is None. We can use the final_single_fn to create a streaming version
        # automatically
        if (final_stream_fn is None):

            async def _stream_from_single_fn(
                    message: final_single_fn_desc.input_type) -> AsyncGenerator[final_single_fn_desc.output_type]:
                value = await final_single_fn(message)

                yield value

            final_stream_fn = _stream_from_single_fn

    else:
        single_output_schema = NoneType

    if (final_stream_fn is not None):

        final_stream_fn_desc = FunctionDescriptor.from_function(final_stream_fn)

        if (final_stream_fn_desc.arg_count > 1):
            if (input_schema is not None):
                logger.warning("Using provided input_schema for multi-argument function")
            else:
                input_schema = final_stream_fn_desc.get_base_model_function_input()

            # Same capture as for the single function: avoid the wrapper recursing into itself
            saved_final_stream_fn = final_stream_fn

            async def _convert_input_pydantic_stream(
                    value: input_schema) -> AsyncGenerator[final_stream_fn_desc.output_type]:

                # Unpack the pydantic model into the arguments
                async for m in saved_final_stream_fn(**value.model_dump()):
                    yield m

            final_stream_fn = _convert_input_pydantic_stream

            # Reset the descriptor
            final_stream_fn_desc = FunctionDescriptor.from_function(final_stream_fn)

        input_schema = input_schema or final_stream_fn_desc.get_base_model_function_input()

        stream_output_schema = stream_output_schema or final_stream_fn_desc.get_base_model_function_output(
            converters)
    else:
        stream_output_schema = NoneType

    # Do the final check for the input schema from the final functions. The input schema must come from the
    # function's *input*; deriving it from the output would silently register the wrong schema.
    if (input_schema is None):

        if (final_single_fn):

            final_single_fn_desc = FunctionDescriptor.from_function(final_single_fn)

            if (final_single_fn_desc.input_type != NoneType):
                input_schema = final_single_fn_desc.get_base_model_function_input()

        elif (final_stream_fn):

            final_stream_fn_desc = FunctionDescriptor.from_function(final_stream_fn)

            if (final_stream_fn_desc.input_type != NoneType):
                input_schema = final_stream_fn_desc.get_base_model_function_input()

        else:
            # Cant be None
            input_schema = NoneType

    return FunctionInfo(single_fn=final_single_fn,
                        stream_fn=final_stream_fn,
                        input_schema=input_schema,
                        single_output_schema=single_output_schema,
                        stream_output_schema=stream_output_schema,
                        description=description,
                        converters=converters)
|
|
552
|
+
|
|
553
|
+
@staticmethod
def from_fn(fn: SingleCallableT | StreamCallableT,
            *,
            input_schema: type[BaseModel] | None = None,
            description: str | None = None,
            converters: list[Callable] | None = None) -> 'FunctionInfo':
    """
    Build a FunctionInfo from one callable, detecting whether it is a stream or a single function.

    An async generator function is registered as the stream function. If its return annotation is a
    `typing.Annotated` whose metadata contains a `Streaming` instance, a single-output function is also created:
    it collects every streamed item into a list and passes the list to that instance's `convert()`. A coroutine
    function is registered as the single function. The result is handed to `FunctionInfo.create()`.

    Parameters
    ----------
    fn : SingleCallableT | StreamCallableT
        The coroutine or async generator function to wrap
    input_schema : type[BaseModel] | None, optional
        A schema object which defines the input to the function, by default None
    description : str | None, optional
        A description to set to the function, by default None
    converters : list[Callable] | None, optional
        A list of converters forwarded to `FunctionInfo.create()`, by default None

    Returns
    -------
    FunctionInfo
        The FunctionInfo object returned by `FunctionInfo.create()`.

    Raises
    ------
    ValueError
        If `fn` is neither an async generator function nor a coroutine function.
    """

    is_stream = inspect.isasyncgenfunction(fn)

    # Reject anything that is neither of the two supported callable kinds
    if not is_stream and not inspect.iscoroutinefunction(fn):
        raise ValueError("Invalid workflow function. Must be an async generator or coroutine")

    stream_fn: StreamCallableT | None = fn if is_stream else None
    single_fn: SingleCallableT | None = None if is_stream else fn

    if is_stream:
        signature = inspect.signature(fn)
        return_annotation = signature.return_annotation

        if typing.get_origin(return_annotation) == typing.Annotated:
            # typing.Annotated[AsyncGenerator[OutputType, None], ...] -> everything after the first argument
            # is metadata
            metadata = typing.get_args(return_annotation)[1:]

            # First Streaming marker in the metadata, if any
            stream_arg = next((item for item in metadata if isinstance(item, Streaming)), None)

            if stream_arg:
                first_param_name = list(signature.parameters.keys())[0]
                single_input_type = signature.parameters[first_param_name].annotation
                single_output_type = stream_arg.single_output_type

                async def _stream_to_single_output(message: single_input_type) -> single_output_type:
                    # Drain the stream completely, then let the marker fold the items into one value
                    values = [m async for m in stream_fn(message)]

                    return stream_arg.convert(values)

                single_fn = _stream_to_single_output

    return FunctionInfo.create(single_fn=single_fn,
                               stream_fn=stream_fn,
                               input_schema=input_schema,
                               description=description,
                               converters=converters or [])
|