nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import sys
|
|
18
|
+
import typing
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
from pydantic import ConfigDict
|
|
22
|
+
from pydantic import Discriminator
|
|
23
|
+
from pydantic import ValidationError
|
|
24
|
+
from pydantic import ValidationInfo
|
|
25
|
+
from pydantic import ValidatorFunctionWrapHandler
|
|
26
|
+
from pydantic import field_validator
|
|
27
|
+
|
|
28
|
+
from nat.data_models.evaluate import EvalConfig
|
|
29
|
+
from nat.data_models.front_end import FrontEndBaseConfig
|
|
30
|
+
from nat.data_models.function import EmptyFunctionConfig
|
|
31
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
32
|
+
from nat.data_models.logging import LoggingBaseConfig
|
|
33
|
+
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
35
|
+
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
36
|
+
|
|
37
|
+
from .authentication import AuthProviderBaseConfig
|
|
38
|
+
from .common import HashableBaseModel
|
|
39
|
+
from .common import TypedBaseModel
|
|
40
|
+
from .embedder import EmbedderBaseConfig
|
|
41
|
+
from .llm import LLMBaseConfig
|
|
42
|
+
from .memory import MemoryBaseConfig
|
|
43
|
+
from .object_store import ObjectStoreBaseConfig
|
|
44
|
+
from .retriever import RetrieverBaseConfig
|
|
45
|
+
|
|
46
|
+
# Module-level logger; _process_validation_error uses it to report unknown or ambiguous component types.
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
    """
    Post-process a ``ValidationError`` raised while validating a component field.

    Two error types receive special treatment:

    * ``union_tag_invalid``: the requested component ``_type`` did not match a registered type. The first such
      error is reported through ``logger.error``, listing the registered type names for the field (or the
      ambiguous matches when the short/local name matches several registered types).
    * ``missing``: the second element of the error location is dropped, and a new ``ValidationError`` is raised
      with the rewritten errors.

    Args:
        err: The error raised by the wrapped pydantic handler.
        handler: The wrapped pydantic handler. Unused; kept so existing callers keep working.
        info: Validation context. ``info.field_name`` selects which registry listing is consulted.

    Returns:
        None when no rewrite was needed. The caller is expected to re-raise ``err`` itself.

    Raises:
        ValidationError: When at least one ``missing`` error was found (raised with the rewritten locations).
        AssertionError: When ``info.field_name`` is not a known component field, or when exactly one registered
            type matches the requested local name (an internal inconsistency).
    """
    from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import

    # Component field name -> GlobalTypeRegistry method listing the types registered for that field.
    # "front_end" (singular) is the field validated by GeneralConfig; "front_ends" is kept for other callers.
    registry_getters = {
        "workflow": "get_registered_functions",
        "functions": "get_registered_functions",
        "authentication": "get_registered_auth_providers",
        "llms": "get_registered_llm_providers",
        "embedders": "get_registered_embedder_providers",
        "memory": "get_registered_memorys",
        "object_stores": "get_registered_object_stores",
        "retrievers": "get_registered_retriever_providers",
        "tracing": "get_registered_telemetry_exporters",
        "logging": "get_registered_logging_method",
        "evaluators": "get_registered_evaluators",
        "front_end": "get_registered_front_ends",
        "front_ends": "get_registered_front_ends",
        "ttc_strategies": "get_registered_ttc_strategies",
    }

    new_errors = []
    logged_once = False
    needs_reraise = False

    for e in err.errors():

        error_type = e['type']

        if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
            requested_type = e["ctx"]["tag"]

            getter_name = registry_getters.get(info.field_name)
            if getter_name is None:
                # Raised explicitly rather than via `assert` so the check survives `python -O`.
                # AssertionError is kept so callers see the same exception type as before.
                raise AssertionError(f"Unknown field name {info.field_name} in validator")

            # Only the registry listing for this field is fetched.
            registered_keys = getattr(GlobalTypeRegistry.get(), getter_name)()

            # Check and see if there are multiple full types which match this short type
            matching_keys = [k for k in registered_keys if k.local_name == requested_type]

            if len(matching_keys) == 1:
                # Explicit raise (not `assert`) so this invariant is also enforced under `python -O`.
                raise AssertionError("Exact match should have been found. Contact developers")

            matching_key_names = [x.full_type for x in matching_keys]
            registered_key_names = [x.full_type for x in registered_keys]

            if not matching_keys:
                # This is a case where the requested type is not found. Show a helpful message about what is
                # available
                logger.error(("Requested %s type `%s` not found. "
                              "Have you ensured the necessary package has been installed with `uv pip install`?"
                              "\nAvailable %s names:\n - %s\n"),
                             info.field_name,
                             requested_type,
                             info.field_name,
                             '\n - '.join(registered_key_names))
            else:
                # This is a case where the requested type is ambiguous.
                logger.error(("Requested %s type `%s` is ambiguous. "
                              "Matched multiple %s by their local name: %s. "
                              "Please use the fully qualified %s name."
                              "\nAvailable %s names:\n - %s\n"),
                             info.field_name,
                             requested_type,
                             info.field_name,
                             matching_key_names,
                             info.field_name,
                             info.field_name,
                             '\n - '.join(registered_key_names))

            # Only show one error
            logged_once = True

        elif error_type == 'missing':
            location = e["loc"]
            if len(location) > 1:
                # Drop the second location element. The original author identifies it as the `_type` field.
                e['loc'] = (location[0], ) + location[2:]
                needs_reraise = True

        new_errors.append(e)

    if needs_reraise:
        raise ValidationError.from_exception_data(title=err.title, line_errors=new_errors)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class TelemetryConfig(BaseModel):
    """Telemetry settings: named logging configurations and named tracing exporters."""

    # Logging configurations, keyed by name.
    logging: dict[str, LoggingBaseConfig] = {}
    # Tracing (telemetry exporter) configurations, keyed by name.
    tracing: dict[str, TelemetryExporterBaseConfig] = {}

    @field_validator("logging", "tracing", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """
        Run pydantic's default validation for ``logging``/``tracing``.

        On failure, the error is handed to ``_process_validation_error``. That helper may log a diagnostic or
        raise a rewritten error; otherwise the original error is re-raised.
        """
        try:
            validated = handler(value)
        except ValidationError as exc:
            _process_validation_error(exc, handler, info)
            raise
        return validated

    @classmethod
    def rebuild_annotations(cls):
        """
        Re-annotate ``tracing`` and ``logging`` from the global type registry.

        Each field becomes ``dict[str, Annotated[<registry annotation>, Discriminator(...)]]``. The model is
        rebuilt only when at least one annotation actually changed.

        Returns:
            The result of ``cls.model_rebuild(force=True)`` if anything changed, otherwise ``False``.
        """
        from nat.cli.type_registry import GlobalTypeRegistry

        registry = GlobalTypeRegistry.get()

        # Dict literal evaluation keeps the original order: tracing is computed before logging.
        desired_annotations = {
            "tracing":
                dict[str,
                     typing.Annotated[registry.compute_annotation(TelemetryExporterBaseConfig),
                                      Discriminator(TypedBaseModel.discriminator)]],
            "logging":
                dict[str,
                     typing.Annotated[registry.compute_annotation(LoggingBaseConfig),
                                      Discriminator(TypedBaseModel.discriminator)]],
        }

        changed = False
        for field_name, annotation in desired_annotations.items():
            field_info = cls.model_fields.get(field_name)
            if field_info is None or field_info.annotation == annotation:
                continue
            field_info.annotation = annotation
            changed = True

        return cls.model_rebuild(force=True) if changed else False
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class GeneralConfig(BaseModel):
    """Workflow-wide settings: event-loop choice, telemetry, and the front end to run."""

    # Empty protected namespaces: pydantic does not check field names against reserved prefixes.
    model_config = ConfigDict(protected_namespaces=())

    use_uvloop: bool = True
    """
    Whether to use uvloop for the event loop. This can provide a significant speedup in some cases. Disable to provide
    better error messages when debugging.
    """

    # Nested telemetry settings (logging and tracing).
    telemetry: TelemetryConfig = TelemetryConfig()

    # FrontEnd Configuration
    front_end: FrontEndBaseConfig = FastApiFrontEndConfig()

    @field_validator("front_end", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """
        Run pydantic's default validation for ``front_end``.

        On failure, the error is handed to ``_process_validation_error``. That helper may log a diagnostic or
        raise a rewritten error; otherwise the original error is re-raised.
        """
        try:
            validated = handler(value)
        except ValidationError as exc:
            _process_validation_error(exc, handler, info)
            raise
        return validated

    @classmethod
    def rebuild_annotations(cls):
        """
        Re-annotate ``front_end`` from the global type registry and refresh the nested telemetry model.

        ``TelemetryConfig.rebuild_annotations`` is always invoked. Because this model embeds ``TelemetryConfig``,
        a telemetry rebuild also forces a rebuild here.

        Returns:
            The result of ``cls.model_rebuild(force=True)`` if anything changed, otherwise ``False``.
        """
        from nat.cli.type_registry import GlobalTypeRegistry

        registry = GlobalTypeRegistry.get()
        front_end_annotation = typing.Annotated[registry.compute_annotation(FrontEndBaseConfig),
                                                Discriminator(TypedBaseModel.discriminator)]

        changed = False

        field_info = cls.model_fields.get("front_end")
        if field_info is not None and field_info.annotation != front_end_annotation:
            field_info.annotation = front_end_annotation
            changed = True

        telemetry_rebuilt = TelemetryConfig.rebuild_annotations()
        if telemetry_rebuilt:
            changed = True

        if not changed:
            return False
        return cls.model_rebuild(force=True)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# Root configuration model: the ``general`` and ``eval`` sections, the ``workflow``, and one name -> config
# dictionary per component kind (functions, LLMs, embedders, memory, object stores, retrievers, TTC strategies,
# authentication providers).
class Config(HashableBaseModel):

    model_config = ConfigDict(extra="forbid")

    # Global Options
    general: GeneralConfig = GeneralConfig()

    # Functions Configuration
    functions: dict[str, FunctionBaseConfig] = {}

    # LLMs Configuration
    llms: dict[str, LLMBaseConfig] = {}

    # Embedders Configuration
    embedders: dict[str, EmbedderBaseConfig] = {}

    # Memory Configuration
    memory: dict[str, MemoryBaseConfig] = {}

    # Object Stores Configuration
    object_stores: dict[str, ObjectStoreBaseConfig] = {}

    # Retriever Configuration
    retrievers: dict[str, RetrieverBaseConfig] = {}

    # TTC Strategies
    ttc_strategies: dict[str, TTCStrategyBaseConfig] = {}

    # Workflow Configuration
    workflow: FunctionBaseConfig = EmptyFunctionConfig()

    # Authentication Configuration
    authentication: dict[str, AuthProviderBaseConfig] = {}

    # Evaluation Options
    eval: EvalConfig = EvalConfig()

    def print_summary(self, stream: typing.TextIO | None = None):
        """Print a summary of the configuration.

        Args:
            stream: Text stream to write to. When omitted, the *current* ``sys.stdout`` is looked up at call
                time, so redirections such as ``contextlib.redirect_stdout`` are honoured. A ``sys.stdout``
                default would be bound once at definition time and ignore them.
        """
        if stream is None:
            stream = sys.stdout

        stream.write("\nConfiguration Summary:\n")
        stream.write("-" * 20 + "\n")
        if self.workflow:
            stream.write(f"Workflow Type: {self.workflow.type}\n")

        stream.write(f"Number of Functions: {len(self.functions)}\n")
        stream.write(f"Number of LLMs: {len(self.llms)}\n")
        stream.write(f"Number of Embedders: {len(self.embedders)}\n")
        stream.write(f"Number of Memory: {len(self.memory)}\n")
        stream.write(f"Number of Object Stores: {len(self.object_stores)}\n")
        stream.write(f"Number of Retrievers: {len(self.retrievers)}\n")
        stream.write(f"Number of TTC Strategies: {len(self.ttc_strategies)}\n")
        stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")

    # Every discriminated component field is listed here, including ``object_stores``, so that all of them
    # get the same error post-processing.
    @field_validator("functions",
                     "llms",
                     "embedders",
                     "memory",
                     "object_stores",
                     "retrievers",
                     "workflow",
                     "ttc_strategies",
                     "authentication",
                     mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """Validate a component field with pydantic's default handler.

        A ``ValidationError`` is handed to ``_process_validation_error`` (with the handler and validation info)
        and then re-raised with a bare ``raise``.
        """
        try:
            return handler(value)
        except ValidationError as err:
            _process_validation_error(err, handler, info)
            raise

    @classmethod
    def rebuild_annotations(cls):
        """Refresh the discriminated-union annotations of every component field from the type registry.

        ``GeneralConfig.rebuild_annotations()`` and ``EvalConfig.rebuild_annotations()`` are always invoked too.
        If anything changed, the model is rebuilt and the result of ``model_rebuild(force=True)`` is returned;
        otherwise ``False`` is returned.
        """
        from nat.cli.type_registry import GlobalTypeRegistry

        type_registry = GlobalTypeRegistry.get()

        def _discriminated(base_type: type) -> typing.Any:
            # Union of the registered configs for ``base_type``, dispatched via TypedBaseModel.discriminator.
            return typing.Annotated[type_registry.compute_annotation(base_type),
                                    Discriminator(TypedBaseModel.discriminator)]

        # Target annotation for each component field (the registry is queried in this order).
        expected_annotations: dict[str, typing.Any] = {
            "llms": dict[str, _discriminated(LLMBaseConfig)],
            "authentication": dict[str, _discriminated(AuthProviderBaseConfig)],
            "embedders": dict[str, _discriminated(EmbedderBaseConfig)],
            "functions": dict[str, _discriminated(FunctionBaseConfig)],
            "memory": dict[str, _discriminated(MemoryBaseConfig)],
            "object_stores": dict[str, _discriminated(ObjectStoreBaseConfig)],
            "retrievers": dict[str, _discriminated(RetrieverBaseConfig)],
            "ttc_strategies": dict[str, _discriminated(TTCStrategyBaseConfig)],
            "workflow": _discriminated(FunctionBaseConfig),
        }

        should_rebuild = False

        for field_name, annotation in expected_annotations.items():
            field_info = cls.model_fields.get(field_name)
            if field_info is not None and field_info.annotation != annotation:
                field_info.annotation = annotation
                should_rebuild = True

        # Nested models must always be refreshed; these calls are deliberately not short-circuited.
        if GeneralConfig.rebuild_annotations():
            should_rebuild = True

        if EvalConfig.rebuild_annotations():
            should_rebuild = True

        if should_rebuild:
            return cls.model_rebuild(force=True)

        return False
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
# Compatibility aliases with previous releases
# ``AIQConfig`` is the same class object as ``Config`` (not a subclass), so isinstance checks and
# validation behave identically under either name.
AIQConfig = Config
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import json
|
|
18
|
+
import typing
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
import pandas as pd
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import Discriminator
|
|
25
|
+
from pydantic import FilePath
|
|
26
|
+
from pydantic import Tag
|
|
27
|
+
|
|
28
|
+
from nat.data_models.common import BaseModelRegistryTag
|
|
29
|
+
from nat.data_models.common import TypedBaseModel
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# S3 connection settings, used by ``EvalDatasetBaseConfig.s3`` to locate a remote dataset file.
class EvalS3Config(BaseModel):

    # Optional custom endpoint (presumably for S3-compatible services); None leaves the client default.
    endpoint_url: str | None = None
    # Optional region; None leaves the client default.
    region_name: str | None = None
    # Bucket that holds the dataset object.
    bucket: str
    # NOTE(review): credentials are plain ``str`` fields, so they appear in reprs/dumps of this model;
    # consider pydantic ``SecretStr`` to mask them.
    access_key: str
    secret_key: str
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# One allow- or deny-list used by ``EvalFilterConfig``.
class EvalFilterEntryConfig(BaseModel):
    # values are lists of allowed/blocked values
    # Keys are presumably dataset field/column names; each maps to the values to match — TODO confirm
    # against the filtering code.
    field: dict[str, list[str | int | float]] = {}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Optional allow-list and deny-list applied to dataset entries; either may be omitted (None).
class EvalFilterConfig(BaseModel):
    # Entries to keep.
    allowlist: EvalFilterEntryConfig | None = None
    # Entries to drop.
    denylist: EvalFilterEntryConfig | None = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Names of the dataset fields that hold each part of an evaluation record.
class EvalDatasetStructureConfig(BaseModel):
    # NOTE(review): presumably, when True, these key mappings are not applied — confirm against the loader.
    disable: bool = False
    # Field holding the input question.
    question_key: str = "question"
    # Field holding the expected answer.
    answer_key: str = "answer"
    # Field holding the answer produced by the workflow.
    generated_answer_key: str = "generated_answer"
    # Field holding the recorded intermediate steps (trajectory).
    trajectory_key: str = "intermediate_steps"
    # Field holding the expected intermediate steps.
    expected_trajectory_key: str = "expected_intermediate_steps"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Base model
# Settings shared by every evaluation-dataset type. Concrete subclasses pass a ``name=`` class keyword
# (their discriminator tag) and define ``parser()``, which returns a ``(reader, reader_kwargs)`` pair.
class EvalDatasetBaseConfig(TypedBaseModel, BaseModelRegistryTag):

    # Name of the field used as each entry's identifier.
    id_key: str = "id"
    # Field-name mapping for questions, answers and trajectories.
    structure: EvalDatasetStructureConfig = EvalDatasetStructureConfig()

    # Filters
    # Optional allow/deny filtering of dataset entries. (The name shadows the builtin ``filter`` only
    # within the class namespace.)
    filter: EvalFilterConfig | None = EvalFilterConfig()

    # Optional S3 source for the dataset file.
    s3: EvalS3Config | None = None

    remote_file_path: str | None = None  # only for s3
    # Local dataset path. NOTE(review): presumably also the download target when ``s3`` is set — confirm.
    file_path: Path | str = Path(".tmp/nat/examples/default/default.json")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Dataset type "json": loaded with ``pandas.read_json`` and no extra keyword arguments.
class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return the ``(reader, reader_kwargs)`` pair used to load a JSON dataset."""
        reader: Callable = pd.read_json
        reader_kwargs: dict = {}
        return reader, reader_kwargs
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def read_jsonl(file_path: FilePath):
    """Load a JSON Lines file into a ``pandas.DataFrame``.

    Each non-blank line is decoded with ``json.loads``. The decoded values are collected in file order and
    passed to ``pandas.DataFrame``.

    Args:
        file_path: Path of the JSON Lines file to read.

    Returns:
        The DataFrame built from the decoded lines; an empty DataFrame when the file has no records.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also drops a leading byte-order mark,
    # which json.loads would otherwise reject on the first line.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        # Skip blank / whitespace-only lines (e.g. extra trailing newlines) instead of failing in json.loads.
        data = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(data)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# Dataset type "jsonl": one JSON value per line, loaded with this module's ``read_jsonl`` helper.
class EvalDatasetJsonlConfig(EvalDatasetBaseConfig, name="jsonl"):

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return the ``(reader, reader_kwargs)`` pair used to load a JSON Lines dataset."""
        reader: Callable = read_jsonl
        reader_kwargs: dict = {}
        return reader, reader_kwargs
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Dataset type "csv": loaded with ``pandas.read_csv`` and no extra keyword arguments.
class EvalDatasetCsvConfig(EvalDatasetBaseConfig, name="csv"):

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return the ``(reader, reader_kwargs)`` pair used to load a CSV dataset."""
        reader: Callable = pd.read_csv
        reader_kwargs: dict = {}
        return reader, reader_kwargs
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Dataset type "parquet": loaded with ``pandas.read_parquet`` and no extra keyword arguments.
class EvalDatasetParquetConfig(EvalDatasetBaseConfig, name="parquet"):

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return the ``(reader, reader_kwargs)`` pair used to load a Parquet dataset."""
        reader: Callable = pd.read_parquet
        reader_kwargs: dict = {}
        return reader, reader_kwargs
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# Dataset type "xls": loaded with ``pandas.read_excel`` using the openpyxl engine.
# NOTE(review): openpyxl reads .xlsx/.xlsm workbooks; legacy binary .xls files would need a different
# engine — confirm which file formats this type is meant to accept.
class EvalDatasetXlsConfig(EvalDatasetBaseConfig, name="xls"):

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return the ``(reader, reader_kwargs)`` pair used to load an Excel dataset."""
        reader: Callable = pd.read_excel
        reader_kwargs: dict = dict(engine="openpyxl")
        return reader, reader_kwargs
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class EvalDatasetCustomConfig(EvalDatasetBaseConfig, name="custom"):
    """
    Configuration for custom dataset type that allows users to specify
    a custom Python function to transform their dataset into EvalInput format.
    """

    function: str  # Direct import path to function, format: "module.path.function_name"
    kwargs: dict[str, typing.Any] = {}  # Additional arguments to pass to the custom function

    def parser(self) -> tuple[Callable, dict]:
        """
        Load and return the custom function for dataset transformation.

        Returns:
            Tuple of (custom_function, kwargs) where custom_function transforms
            a dataset file into an EvalInput object.
        """
        custom_function = self._load_custom_function()
        return custom_function, self.kwargs

    def _load_custom_function(self) -> Callable:
        """
        Import and return the custom function using standard Python import path.

        Raises:
            ValueError: If ``function`` is empty, has no module part (e.g. ``"my_func"`` or ``".my_func"``),
                or resolves to a non-callable object.
            ImportError: If the module cannot be imported (raised by ``importlib.import_module``).
            AttributeError: If the module has no attribute with the requested name.
        """
        if not self.function:
            raise ValueError("Function path cannot be empty")

        # rpartition never raises. Without a "." the module part is empty, and that case is reported
        # explicitly below rather than surfacing as an opaque tuple-unpacking / "Empty module name" error.
        module_path, _, function_name = self.function.rpartition(".")
        if not module_path:
            raise ValueError(f"Invalid function path '{self.function}': expected format 'module.path.function_name'")

        # Import the module
        module = importlib.import_module(module_path)

        # Single attribute lookup (EAFP) instead of hasattr() followed by getattr().
        try:
            custom_function = getattr(module, function_name)
        except AttributeError:
            raise AttributeError(f"Function '{function_name}' not found in module '{module_path}'") from None

        if not callable(custom_function):
            raise ValueError(f"'{self.function}' is not callable")

        return custom_function
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# Union model with discriminator
# Discriminated union of the built-in dataset configs. Each member is tagged with its ``static_type()``,
# and ``TypedBaseModel.discriminator`` picks the member whose tag matches the incoming data.
# Custom dataset types must be added here explicitly to be accepted.
EvalDatasetConfig = typing.Annotated[
    typing.Annotated[EvalDatasetJsonConfig, Tag(EvalDatasetJsonConfig.static_type())]
    | typing.Annotated[EvalDatasetCsvConfig, Tag(EvalDatasetCsvConfig.static_type())]
    | typing.Annotated[EvalDatasetXlsConfig, Tag(EvalDatasetXlsConfig.static_type())]
    | typing.Annotated[EvalDatasetParquetConfig, Tag(EvalDatasetParquetConfig.static_type())]
    | typing.Annotated[EvalDatasetJsonlConfig, Tag(EvalDatasetJsonlConfig.static_type())]
    | typing.Annotated[EvalDatasetCustomConfig, Tag(EvalDatasetCustomConfig.static_type())],
    Discriminator(TypedBaseModel.discriminator)]
|