nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +1 -1
- nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
- nat/agent/auto_memory_wrapper/agent.py +278 -0
- nat/agent/auto_memory_wrapper/register.py +227 -0
- nat/agent/auto_memory_wrapper/state.py +30 -0
- nat/agent/base.py +1 -1
- nat/agent/dual_node.py +1 -1
- nat/agent/prompt_optimizer/prompt.py +1 -1
- nat/agent/prompt_optimizer/register.py +1 -1
- nat/agent/react_agent/agent.py +16 -9
- nat/agent/react_agent/output_parser.py +2 -2
- nat/agent/react_agent/prompt.py +3 -2
- nat/agent/react_agent/register.py +2 -2
- nat/agent/react_agent/register_per_user_agent.py +104 -0
- nat/agent/reasoning_agent/reasoning_agent.py +1 -1
- nat/agent/register.py +3 -1
- nat/agent/responses_api_agent/__init__.py +1 -1
- nat/agent/responses_api_agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +9 -4
- nat/agent/rewoo_agent/prompt.py +1 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +5 -4
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/__init__.py +1 -1
- nat/authentication/api_key/__init__.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
- nat/authentication/api_key/register.py +1 -1
- nat/authentication/credential_validator/__init__.py +1 -1
- nat/authentication/credential_validator/bearer_token_validator.py +1 -1
- nat/authentication/exceptions/__init__.py +1 -1
- nat/authentication/exceptions/api_key_exceptions.py +1 -1
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/http_basic_auth/register.py +1 -1
- nat/authentication/interfaces.py +1 -1
- nat/authentication/oauth2/__init__.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
- nat/authentication/oauth2/register.py +1 -1
- nat/authentication/register.py +1 -1
- nat/builder/builder.py +563 -1
- nat/builder/child_builder.py +385 -0
- nat/builder/component_utils.py +34 -4
- nat/builder/context.py +34 -1
- nat/builder/embedder.py +1 -1
- nat/builder/eval_builder.py +19 -7
- nat/builder/evaluator.py +1 -1
- nat/builder/framework_enum.py +3 -1
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +113 -5
- nat/builder/function_base.py +1 -1
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +1 -1
- nat/builder/llm.py +1 -1
- nat/builder/per_user_workflow_builder.py +843 -0
- nat/builder/retriever.py +1 -1
- nat/builder/sync_builder.py +571 -0
- nat/builder/user_interaction_manager.py +1 -1
- nat/builder/workflow.py +5 -3
- nat/builder/workflow_builder.py +619 -378
- nat/cli/__init__.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/cli_utils/validation.py +32 -1
- nat/cli/commands/configure/channel/add.py +1 -1
- nat/cli/commands/configure/channel/channel.py +1 -1
- nat/cli/commands/configure/channel/remove.py +1 -1
- nat/cli/commands/configure/channel/update.py +1 -1
- nat/cli/commands/configure/configure.py +1 -1
- nat/cli/commands/evaluate.py +87 -13
- nat/cli/commands/finetune.py +132 -0
- nat/cli/commands/info/__init__.py +1 -1
- nat/cli/commands/info/info.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +1 -1
- nat/cli/commands/object_store/__init__.py +1 -1
- nat/cli/commands/object_store/object_store.py +1 -1
- nat/cli/commands/optimize.py +1 -1
- nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
- nat/cli/commands/red_teaming/red_teaming.py +138 -0
- nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
- nat/cli/commands/registry/__init__.py +1 -1
- nat/cli/commands/registry/publish.py +1 -1
- nat/cli/commands/registry/pull.py +1 -1
- nat/cli/commands/registry/registry.py +1 -1
- nat/cli/commands/registry/remove.py +1 -1
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/sizing/__init__.py +1 -1
- nat/cli/commands/sizing/calc.py +1 -1
- nat/cli/commands/sizing/sizing.py +1 -1
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/validate.py +1 -1
- nat/cli/commands/workflow/__init__.py +1 -1
- nat/cli/commands/workflow/workflow.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/entrypoint.py +15 -37
- nat/cli/main.py +2 -2
- nat/cli/plugin_loader.py +69 -0
- nat/cli/register_workflow.py +233 -5
- nat/cli/type_registry.py +237 -3
- nat/control_flow/register.py +1 -1
- nat/control_flow/router_agent/agent.py +1 -1
- nat/control_flow/router_agent/prompt.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +28 -7
- nat/data_models/__init__.py +1 -1
- nat/data_models/agent.py +1 -1
- nat/data_models/api_server.py +38 -3
- nat/data_models/authentication.py +1 -1
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +9 -1
- nat/data_models/component_ref.py +45 -1
- nat/data_models/config.py +78 -1
- nat/data_models/dataset_handler.py +15 -2
- nat/data_models/discovery_metadata.py +1 -1
- nat/data_models/embedder.py +1 -1
- nat/data_models/evaluate.py +6 -1
- nat/data_models/evaluator.py +1 -1
- nat/data_models/finetuning.py +260 -0
- nat/data_models/front_end.py +1 -1
- nat/data_models/function.py +15 -2
- nat/data_models/function_dependencies.py +1 -1
- nat/data_models/gated_field_mixin.py +1 -1
- nat/data_models/interactive.py +1 -1
- nat/data_models/intermediate_step.py +29 -2
- nat/data_models/invocation_node.py +1 -1
- nat/data_models/llm.py +1 -1
- nat/data_models/logging.py +1 -1
- nat/data_models/memory.py +1 -1
- nat/data_models/middleware.py +37 -0
- nat/data_models/object_store.py +1 -1
- nat/data_models/openai_mcp.py +1 -1
- nat/data_models/optimizable.py +1 -1
- nat/data_models/optimizer.py +1 -1
- nat/data_models/profiler.py +1 -1
- nat/data_models/registry_handler.py +1 -1
- nat/data_models/retriever.py +1 -1
- nat/data_models/retry_mixin.py +1 -1
- nat/data_models/runtime_enum.py +26 -0
- nat/data_models/span.py +1 -1
- nat/data_models/step_adaptor.py +1 -1
- nat/data_models/streaming.py +1 -1
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/telemetry_exporter.py +1 -1
- nat/data_models/thinking_mixin.py +1 -1
- nat/data_models/ttc_strategy.py +1 -1
- nat/embedder/azure_openai_embedder.py +1 -1
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +1 -1
- nat/eval/__init__.py +1 -1
- nat/eval/config.py +8 -1
- nat/eval/dataset_handler/dataset_downloader.py +1 -1
- nat/eval/dataset_handler/dataset_filter.py +1 -1
- nat/eval/dataset_handler/dataset_handler.py +4 -2
- nat/eval/evaluate.py +226 -81
- nat/eval/evaluator/__init__.py +1 -1
- nat/eval/evaluator/base_evaluator.py +2 -2
- nat/eval/evaluator/evaluator_model.py +3 -2
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/llm_validator.py +336 -0
- nat/eval/rag_evaluator/evaluate.py +17 -10
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/red_teaming_evaluator/__init__.py +14 -0
- nat/eval/red_teaming_evaluator/data_models.py +66 -0
- nat/eval/red_teaming_evaluator/evaluate.py +327 -0
- nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
- nat/eval/red_teaming_evaluator/register.py +55 -0
- nat/eval/register.py +2 -1
- nat/eval/remote_workflow.py +1 -1
- nat/eval/runners/__init__.py +1 -1
- nat/eval/runners/config.py +1 -1
- nat/eval/runners/multi_eval_runner.py +1 -1
- nat/eval/runners/red_teaming_runner/__init__.py +24 -0
- nat/eval/runners/red_teaming_runner/config.py +282 -0
- nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
- nat/eval/runners/red_teaming_runner/runner.py +867 -0
- nat/eval/runtime_evaluator/__init__.py +1 -1
- nat/eval/runtime_evaluator/evaluate.py +1 -1
- nat/eval/runtime_evaluator/register.py +1 -1
- nat/eval/runtime_event_subscriber.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/swe_bench_evaluator/register.py +1 -1
- nat/eval/trajectory_evaluator/evaluate.py +2 -2
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
- nat/eval/tunable_rag_evaluator/register.py +1 -1
- nat/eval/usage_stats.py +1 -1
- nat/eval/utils/eval_trace_ctx.py +1 -1
- nat/eval/utils/output_uploader.py +1 -1
- nat/eval/utils/tqdm_position_registry.py +1 -1
- nat/eval/utils/weave_eval.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -1
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/editor_config.py +1 -1
- nat/experimental/test_time_compute/models/scoring_config.py +1 -1
- nat/experimental/test_time_compute/models/search_config.py +20 -2
- nat/experimental/test_time_compute/models/selection_config.py +33 -2
- nat/experimental/test_time_compute/models/stage_enums.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +1 -1
- nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
- nat/experimental/test_time_compute/models/ttc_item.py +1 -1
- nat/experimental/test_time_compute/register.py +4 -1
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
- nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
- nat/finetuning/__init__.py +24 -0
- nat/finetuning/finetuning_runtime.py +143 -0
- nat/finetuning/interfaces/__init__.py +24 -0
- nat/finetuning/interfaces/finetuning_runner.py +261 -0
- nat/finetuning/interfaces/trainer_adapter.py +103 -0
- nat/finetuning/interfaces/trajectory_builder.py +115 -0
- nat/finetuning/utils/__init__.py +15 -0
- nat/finetuning/utils/parsers/__init__.py +15 -0
- nat/finetuning/utils/parsers/adk_parser.py +141 -0
- nat/finetuning/utils/parsers/base_parser.py +238 -0
- nat/finetuning/utils/parsers/common.py +91 -0
- nat/finetuning/utils/parsers/langchain_parser.py +267 -0
- nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
- nat/front_ends/__init__.py +1 -1
- nat/front_ends/console/__init__.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +1 -1
- nat/front_ends/console/console_front_end_config.py +4 -1
- nat/front_ends/console/console_front_end_plugin.py +5 -4
- nat/front_ends/console/register.py +1 -1
- nat/front_ends/cron/__init__.py +1 -1
- nat/front_ends/fastapi/__init__.py +1 -1
- nat/front_ends/fastapi/async_job.py +128 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
- nat/front_ends/fastapi/dask_client_mixin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_config.py +23 -1
- nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +318 -59
- nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
- nat/front_ends/fastapi/job_store.py +23 -11
- nat/front_ends/fastapi/main.py +1 -1
- nat/front_ends/fastapi/message_handler.py +27 -4
- nat/front_ends/fastapi/message_validator.py +54 -2
- nat/front_ends/fastapi/register.py +1 -1
- nat/front_ends/fastapi/response_helpers.py +16 -15
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/fastapi/utils.py +1 -1
- nat/front_ends/register.py +1 -2
- nat/front_ends/simple_base/__init__.py +1 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
- nat/llm/aws_bedrock_llm.py +1 -1
- nat/llm/azure_openai_llm.py +10 -1
- nat/llm/dynamo_llm.py +363 -0
- nat/llm/huggingface_llm.py +177 -0
- nat/llm/litellm_llm.py +1 -1
- nat/llm/nim_llm.py +1 -1
- nat/llm/openai_llm.py +1 -1
- nat/llm/register.py +3 -1
- nat/llm/utils/__init__.py +1 -1
- nat/llm/utils/env_config_value.py +1 -1
- nat/llm/utils/error.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/memory/__init__.py +1 -1
- nat/memory/interfaces.py +1 -1
- nat/memory/models.py +1 -1
- nat/meta/pypi.md +1 -1
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache/__init__.py +14 -0
- nat/middleware/cache/cache_middleware.py +253 -0
- nat/middleware/cache/cache_middleware_config.py +44 -0
- nat/middleware/cache/register.py +33 -0
- nat/middleware/defense/__init__.py +14 -0
- nat/middleware/defense/defense_middleware.py +362 -0
- nat/middleware/defense/defense_middleware_content_guard.py +455 -0
- nat/middleware/defense/defense_middleware_data_models.py +91 -0
- nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
- nat/middleware/defense/defense_middleware_pii.py +356 -0
- nat/middleware/defense/register.py +82 -0
- nat/middleware/dynamic/__init__.py +14 -0
- nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
- nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
- nat/middleware/dynamic/register.py +34 -0
- nat/middleware/function_middleware.py +370 -0
- nat/middleware/logging/__init__.py +14 -0
- nat/middleware/logging/logging_middleware.py +67 -0
- nat/middleware/logging/logging_middleware_config.py +28 -0
- nat/middleware/logging/register.py +33 -0
- nat/middleware/middleware.py +298 -0
- nat/middleware/red_teaming/__init__.py +14 -0
- nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
- nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
- nat/middleware/red_teaming/register.py +47 -0
- nat/middleware/register.py +22 -0
- nat/middleware/utils/__init__.py +14 -0
- nat/middleware/utils/workflow_inventory.py +155 -0
- nat/object_store/__init__.py +1 -1
- nat/object_store/in_memory_object_store.py +1 -1
- nat/object_store/interfaces.py +1 -1
- nat/object_store/models.py +1 -1
- nat/object_store/register.py +1 -1
- nat/observability/__init__.py +1 -1
- nat/observability/exporter/__init__.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +1 -1
- nat/observability/exporter/raw_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +7 -1
- nat/observability/exporter_manager.py +1 -1
- nat/observability/mixin/__init__.py +1 -1
- nat/observability/mixin/batch_config_mixin.py +1 -1
- nat/observability/mixin/collector_config_mixin.py +1 -1
- nat/observability/mixin/file_mixin.py +1 -1
- nat/observability/mixin/file_mode.py +1 -1
- nat/observability/mixin/redaction_config_mixin.py +1 -1
- nat/observability/mixin/resource_conflict_mixin.py +1 -1
- nat/observability/mixin/serialize_mixin.py +1 -1
- nat/observability/mixin/tagging_config_mixin.py +1 -1
- nat/observability/mixin/type_introspection_mixin.py +1 -1
- nat/observability/processor/__init__.py +1 -1
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/callback_processor.py +1 -1
- nat/observability/processor/falsy_batch_filter_processor.py +1 -1
- nat/observability/processor/intermediate_step_serializer.py +1 -1
- nat/observability/processor/processor.py +1 -1
- nat/observability/processor/processor_factory.py +1 -1
- nat/observability/processor/redaction/__init__.py +1 -1
- nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
- nat/observability/processor/redaction/redaction_processor.py +1 -1
- nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
- nat/observability/processor/span_tagging_processor.py +1 -1
- nat/observability/register.py +1 -1
- nat/observability/utils/__init__.py +1 -1
- nat/observability/utils/dict_utils.py +1 -1
- nat/observability/utils/time_utils.py +1 -1
- nat/profiler/calc/__init__.py +1 -1
- nat/profiler/calc/calc_runner.py +3 -3
- nat/profiler/calc/calculations.py +1 -1
- nat/profiler/calc/data_models.py +1 -1
- nat/profiler/calc/plot.py +30 -3
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/base_callback_class.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +33 -3
- nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/callbacks/token_usage_base_model.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/data_models.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +32 -1
- nat/profiler/decorators/function_tracking.py +1 -1
- nat/profiler/forecasting/config.py +1 -1
- nat/profiler/forecasting/model_trainer.py +1 -1
- nat/profiler/forecasting/models/__init__.py +1 -1
- nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_metrics_model.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +1 -1
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/inference_optimization/llm_metrics.py +1 -1
- nat/profiler/inference_optimization/prompt_caching.py +1 -1
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
- nat/profiler/intermediate_property_adapter.py +1 -1
- nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
- nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
- nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
- nat/profiler/parameter_optimization/parameter_selection.py +1 -1
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
- nat/profiler/parameter_optimization/update_helpers.py +1 -1
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +1 -1
- nat/registry_handlers/local/register_local.py +1 -1
- nat/registry_handlers/metadata_factory.py +1 -1
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +1 -1
- nat/registry_handlers/pypi/register_pypi.py +1 -1
- nat/registry_handlers/register.py +1 -1
- nat/registry_handlers/registry_handler_base.py +1 -1
- nat/registry_handlers/rest/register_rest.py +1 -1
- nat/registry_handlers/rest/rest_handler.py +1 -1
- nat/registry_handlers/schemas/headers.py +1 -1
- nat/registry_handlers/schemas/package.py +1 -1
- nat/registry_handlers/schemas/publish.py +1 -1
- nat/registry_handlers/schemas/pull.py +1 -1
- nat/registry_handlers/schemas/remove.py +1 -1
- nat/registry_handlers/schemas/search.py +1 -1
- nat/registry_handlers/schemas/status.py +1 -1
- nat/retriever/interface.py +1 -1
- nat/retriever/milvus/__init__.py +1 -1
- nat/retriever/milvus/register.py +12 -4
- nat/retriever/milvus/retriever.py +103 -41
- nat/retriever/models.py +1 -1
- nat/retriever/nemo_retriever/__init__.py +1 -1
- nat/retriever/nemo_retriever/register.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +5 -5
- nat/retriever/register.py +1 -1
- nat/runtime/__init__.py +1 -1
- nat/runtime/loader.py +10 -3
- nat/runtime/metrics.py +180 -0
- nat/runtime/runner.py +13 -6
- nat/runtime/session.py +458 -32
- nat/runtime/user_metadata.py +1 -1
- nat/settings/global_settings.py +1 -1
- nat/tool/chat_completion.py +1 -1
- nat/tool/code_execution/README.md +1 -1
- nat/tool/code_execution/code_sandbox.py +2 -2
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
- nat/tool/code_execution/local_sandbox/__init__.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
- nat/tool/code_execution/register.py +1 -1
- nat/tool/code_execution/utils.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +1 -1
- nat/tool/memory_tools/delete_memory_tool.py +1 -1
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +2 -2
- nat/tool/register.py +1 -1
- nat/tool/retriever.py +1 -1
- nat/tool/server_tools.py +1 -1
- nat/utils/__init__.py +8 -5
- nat/utils/callable_utils.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/debugging_utils.py +1 -1
- nat/utils/decorators.py +1 -1
- nat/utils/dump_distro_mapping.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +3 -3
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/model_processing.py +1 -1
- nat/utils/io/supress_logs.py +33 -0
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/log_levels.py +1 -1
- nat/utils/log_utils.py +13 -1
- nat/utils/metadata_utils.py +1 -1
- nat/utils/optional_imports.py +1 -1
- nat/utils/producer_consumer_queue.py +1 -1
- nat/utils/reactive/base/observable_base.py +1 -1
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/base/subject_base.py +1 -1
- nat/utils/reactive/observable.py +1 -1
- nat/utils/reactive/observer.py +1 -1
- nat/utils/reactive/subject.py +1 -1
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/responses_api.py +1 -1
- nat/utils/settings/global_settings.py +1 -1
- nat/utils/string_utils.py +1 -1
- nat/utils/type_converter.py +18 -5
- nat/utils/type_utils.py +1 -1
- nat/utils/url_utils.py +1 -1
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +46 -15
- nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
- nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
- nat/cli/commands/mcp/mcp.py +0 -986
- nat/front_ends/mcp/introspection_token_verifier.py +0 -73
- nat/front_ends/mcp/mcp_front_end_config.py +0 -109
- nat/front_ends/mcp/mcp_front_end_plugin.py +0 -151
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -362
- nat/front_ends/mcp/memory_profiler.py +0 -320
- nat/front_ends/mcp/register.py +0 -27
- nat/front_ends/mcp/tool_converter.py +0 -321
- nvidia_nat-1.4.0a20251112.dist-info/RECORD +0 -481
- nvidia_nat-1.4.0a20251112.dist-info/entry_points.txt +0 -22
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Base middleware class for the NeMo Agent toolkit.
|
|
16
|
+
|
|
17
|
+
This module provides the base Middleware class that defines the middleware pattern
|
|
18
|
+
for wrapping and modifying function calls. It follows the same pattern as middleware in
|
|
19
|
+
web frameworks: each middleware can modify inputs, call the next middleware in the chain,
|
|
20
|
+
process outputs, and continue.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import dataclasses
|
|
26
|
+
from abc import ABC
|
|
27
|
+
from abc import abstractmethod
|
|
28
|
+
from collections.abc import AsyncIterator
|
|
29
|
+
from collections.abc import Awaitable
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
from typing import Any
|
|
32
|
+
|
|
33
|
+
from pydantic import BaseModel
|
|
34
|
+
from pydantic import ConfigDict
|
|
35
|
+
from pydantic import Field
|
|
36
|
+
|
|
37
|
+
#: Type alias for single-output invocation callables (e.g. a ``call_next``
#: continuation): accepts any arguments and returns an awaitable of the result.
CallNext = Callable[..., Awaitable[Any]]

#: Type alias for streaming invocation callables: accepts any arguments and
#: returns an async iterator that yields output chunks.
CallNextStream = Callable[..., AsyncIterator[Any]]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
45
|
+
class FunctionMiddlewareContext:
|
|
46
|
+
"""Static metadata about the function being wrapped by middleware.
|
|
47
|
+
|
|
48
|
+
Middleware receives this context object which describes the function they
|
|
49
|
+
are wrapping. This allows middleware to make decisions based on the
|
|
50
|
+
function's name, configuration, schema, etc.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
name: str
|
|
54
|
+
"""Name of the function being wrapped."""
|
|
55
|
+
|
|
56
|
+
config: Any
|
|
57
|
+
"""Configuration object for the function."""
|
|
58
|
+
|
|
59
|
+
description: str | None
|
|
60
|
+
"""Optional description of the function."""
|
|
61
|
+
|
|
62
|
+
input_schema: type[BaseModel] | None
|
|
63
|
+
"""Schema describing expected inputs or :class:`NoneType` when absent."""
|
|
64
|
+
|
|
65
|
+
single_output_schema: type[BaseModel] | type[None]
|
|
66
|
+
"""Schema describing single outputs or :class:`types.NoneType` when absent."""
|
|
67
|
+
|
|
68
|
+
stream_output_schema: type[BaseModel] | type[None]
|
|
69
|
+
"""Schema describing streaming outputs or :class:`types.NoneType` when absent."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class InvocationContext(BaseModel):
    """Context shared by the pre-invoke and post-invoke middleware hooks.

    A single model type serves both phases of middleware execution:

    - Pre-invoke: ``output`` is ``None``; assign to ``modified_args`` /
      ``modified_kwargs`` to transform the inputs handed down the chain.
    - Post-invoke: ``output`` holds the function result; assign to ``output``
      to transform what is returned.

    Sharing one context type between both hooks keeps the middleware
    interface small.
    """

    # Assignments are validated against the field types at runtime.
    model_config = ConfigDict(validate_assignment=True)

    # ---- Read-only snapshot (frozen=True: reassigning raises ValidationError) ----
    # NOTE: frozen only blocks reassignment of the attribute itself; the dict
    # stored in original_kwargs is still an ordinary mutable object.
    function_context: FunctionMiddlewareContext = Field(
        description="Static metadata about the function being invoked (frozen).",
        frozen=True,
    )
    original_args: tuple[Any, ...] = Field(
        description="The original function input arguments before any middleware processing.",
        frozen=True,
    )
    original_kwargs: dict[str, Any] = Field(
        description="The original function input keyword arguments before any middleware processing.",
        frozen=True,
    )

    # ---- Writable state: middleware assigns to these to change inputs/outputs ----
    modified_args: tuple[Any, ...] = Field(description="Modified args after middleware processing.")
    modified_kwargs: dict[str, Any] = Field(description="Modified kwargs after middleware processing.")
    output: Any = Field(None, description="Function output. None pre-invoke, result post-invoke.")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Middleware(ABC):
    """Abstract base for wrappers that run hooks around a target invocation.

    A middleware layer behaves like one in a web framework. For each call it can:

    1. **Preprocess** -- look at, and optionally change, the inputs (``pre_invoke``).
    2. **Call Next** -- hand control to the next middleware, or to the target itself.
    3. **Postprocess** -- inspect, transform, or enrich the output (``post_invoke``).
    4. **Continue** -- return (or yield) the final result to its caller.

    Example::

        class LoggingMiddleware(FunctionMiddleware):
            @property
            def enabled(self) -> bool:
                return True

            async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
                print(f"Current args: {context.modified_args}")
                print(f"Original args: {context.original_args}")
                return None  # Pass through unchanged

            async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
                print(f"Output: {context.output}")
                return None  # Pass through unchanged

    Attributes:
        is_final: Declares that this middleware terminates the chain -- later
            middleware and the target are not called unless it delegates to
            ``call_next`` itself. This base class only stores the flag; honoring
            it is the job of whatever assembles the chain.
    """

    def __init__(self, *, is_final: bool = False) -> None:
        # Kept private; callers read it through the read-only ``is_final`` property.
        self._is_final = is_final

    # ------------------------------------------------------------------ abstract hooks

    @property
    @abstractmethod
    def enabled(self) -> bool:
        """Report whether this middleware should run (each subclass decides)."""

    @abstractmethod
    async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
        """Hook that runs before the target is executed.

        Specialized invoke methods (for example ``function_middleware_invoke``) call
        this so a subclass can validate, rewrite, or enrich the inputs. The target
        has not run yet, so ``context.output`` is still ``None``.

        Args:
            context: The invocation context (a Pydantic model) carrying:

                - ``function_context``: static metadata about the function (frozen)
                - ``original_args`` / ``original_kwargs``: the arguments that entered
                  the middleware chain (frozen)
                - ``modified_args`` / ``modified_kwargs``: the current arguments (mutable)
                - ``output``: ``None`` at this stage

        Returns:
            The (possibly modified) context to report a change, or ``None`` to pass
            through unchanged, in which case the framework keeps the current context.

        Note:
            The frozen fields cannot be reassigned; trying to do so raises a
            ``ValidationError``.

        Raises:
            Any exception, to abort the invocation.
        """

    @abstractmethod
    async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
        """Hook that runs after the target has produced output.

        Specialized invoke methods (for example ``function_middleware_invoke``) call
        this to let a subclass validate, transform, or enrich the result. For
        streaming calls it is invoked once per chunk.

        Args:
            context: The invocation context (a Pydantic model) carrying:

                - ``function_context``: static metadata about the function (frozen)
                - ``original_args`` / ``original_kwargs``: the arguments that entered
                  the middleware chain (frozen)
                - ``modified_args`` / ``modified_kwargs``: what the function actually
                  received (mutable)
                - ``output``: the current output value (mutable)

        Returns:
            The (possibly modified) context to report a change, or ``None`` to pass
            through unchanged, in which case the framework keeps ``context.output``.

        Example::

            async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
                # Wrap the output
                context.output = {"result": context.output, "processed": True}
                return context  # Signal modification

        Raises:
            Any exception, to abort and propagate the error.
        """

    # ---------------------------------------------------------------------- properties

    @property
    def is_final(self) -> bool:
        """True when this middleware was constructed as the end of the chain.

        A final middleware is meant to stop later middleware and the target from
        running unless it explicitly calls ``call_next``.
        """
        return self._is_final

    # -------------------------------------------------------- default invoke methods

    async def middleware_invoke(self,
                                value: Any,
                                call_next: CallNext,
                                context: FunctionMiddlewareContext,
                                **kwargs: Any) -> Any:
        """Wrap a single-output invocation.

        Args:
            value: The input value to process.
            call_next: Callable that invokes the next middleware or the target.
            context: Metadata about the target being wrapped.
            kwargs: Any additional function arguments.

        Returns:
            The (possibly modified) output produced downstream.

        This default is a transparent pass-through to ``call_next``. Subclasses can
        override it to preprocess, postprocess, or short-circuit the call::

            async def middleware_invoke(self, value, call_next, context, **kwargs):
                result = await call_next(transform(value), **kwargs)
                return transform_output(result)
        """
        del context  # not needed by the pass-through default
        pending = call_next(value, **kwargs)
        return await pending

    async def middleware_stream(self,
                                value: Any,
                                call_next: CallNextStream,
                                context: FunctionMiddlewareContext,
                                **kwargs: Any) -> AsyncIterator[Any]:
        """Wrap a streaming invocation.

        Args:
            value: The input value to process.
            call_next: Callable that returns the async stream of the next middleware
                or the target.
            context: Metadata about the target being wrapped.
            kwargs: Any additional function arguments.

        Yields:
            Each chunk produced downstream (possibly modified by an override).

        This default re-yields every chunk from ``call_next`` unchanged. Subclasses
        can override it to preprocess the input, transform chunks, or clean up::

            async def middleware_stream(self, value, call_next, context, **kwargs):
                async for chunk in call_next(transform(value), **kwargs):
                    yield transform_chunk(chunk)
                await cleanup()
        """
        del context  # not needed by the pass-through default
        upstream = call_next(value, **kwargs)
        async for piece in upstream:
            yield piece
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
# Public names of this module (also what ``from ... import *`` exposes).
__all__ = ["CallNext", "CallNextStream", "FunctionMiddlewareContext", "InvocationContext", "Middleware"]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Red teaming middleware for attacking agent functions.
|
|
16
|
+
|
|
17
|
+
This module provides a middleware for red teaming and security testing that can
|
|
18
|
+
intercept and modify function inputs or outputs with configurable attack payloads.
|
|
19
|
+
|
|
20
|
+
The middleware supports:
|
|
21
|
+
- Targeting specific functions or entire function groups
|
|
22
|
+
- Field-level search within input/output schemas
|
|
23
|
+
- Multiple attack modes (replace, append_start, append_middle, append_end)
|
|
24
|
+
- Both regular and streaming function calls
|
|
25
|
+
- Type-safe operations on strings, integers, and floats
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import copy
import logging
import random
import re
from typing import Any
from typing import Literal
from typing import cast

from jsonpath_ng import parse
from pydantic import BaseModel

from nat.middleware.function_middleware import CallNext
from nat.middleware.function_middleware import FunctionMiddleware
from nat.middleware.function_middleware import FunctionMiddlewareContext
|
|
43
|
+
|
|
44
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class RedTeamingMiddleware(FunctionMiddleware):
    """Middleware for red teaming that intercepts and modifies function inputs/outputs.

    This middleware enables systematic security testing by injecting attack payloads
    into function inputs or outputs. It supports flexible targeting, field-level
    modifications, and multiple attack modes.

    Features:

    * Target specific functions or entire function groups
    * Search for specific fields in input/output schemas (JSONPath ``target_field``)
    * Apply attacks via replace or append modes
    * Type-safe operations on strings and numbers (int/float)
    * Optional cap on how many times a payload is applied (``call_limit``)

    Only ``function_middleware_invoke`` (single-output calls) is overridden here;
    streaming calls get whatever behavior ``FunctionMiddleware`` provides.

    The caller's objects are never mutated in place: dict/list values are deep-copied,
    and Pydantic models are dumped to a dict and rebuilt, before a payload is applied.

    Example::

        # In YAML config
        middleware:
          prompt_injection:
            _type: red_teaming
            attack_payload: "Ignore previous instructions"
            target_function_or_group: my_llm.generate
            payload_placement: append_start
            target_location: input
            target_field: prompt

    Args:
        attack_payload: The malicious payload to inject.
        target_function_or_group: Function or group to target (None for all).
        payload_placement: How to apply (replace, append_start, append_middle, append_end).
        target_location: Whether to attack input or output.
        target_field: Field name or JSONPath to attack (None for direct value).
        target_field_resolution_strategy: What to do when ``target_field`` matches
            several fields (random, first, last, all, error).
        call_limit: Maximum number of times a payload is applied (None for unlimited).
    """

    def __init__(
        self,
        *,
        attack_payload: str,
        target_function_or_group: str | None = None,
        payload_placement: Literal["replace", "append_start", "append_middle", "append_end"] = "append_end",
        target_location: Literal["input", "output"] = "input",
        target_field: str | None = None,
        target_field_resolution_strategy: Literal["random", "first", "last", "all", "error"] = "error",
        call_limit: int | None = None,
    ) -> None:
        """Initialize red teaming middleware.

        Args:
            attack_payload: The value to inject to the function input or output.
            target_function_or_group: Optional function/group to target.
            payload_placement: How to apply the payload (replace or append modes).
            target_location: Whether to place the payload in the input or output.
            target_field: JSONPath to the field to attack.
            target_field_resolution_strategy: Strategy (random/first/last/all/error).
            call_limit: Maximum number of times the middleware will apply a payload.
        """
        super().__init__(is_final=False)
        self._attack_payload = attack_payload
        self._target_function_or_group = target_function_or_group
        self._payload_placement = payload_placement
        self._target_location = target_location
        self._target_field = target_field
        self._target_field_resolution_strategy = target_field_resolution_strategy
        self._call_count: int = 0  # Count the number of times the middleware has applied a payload
        self._call_limit = call_limit
        logger.info(
            "RedTeamingMiddleware initialized: payload=%s, target=%s, placement=%s, location=%s, field=%s",
            attack_payload,
            target_function_or_group,
            payload_placement,
            target_location,
            target_field,
        )

    def _should_apply_payload(self, context_name: str) -> bool:
        """Check if this function should be attacked based on targeting configuration.

        Args:
            context_name: The name of the function from context (e.g., "calculator.add")

        Returns:
            True if the function should be attacked, False otherwise
        """
        # If no target specified, attack all functions
        if self._target_function_or_group is None:
            return True

        target = self._target_function_or_group

        # Group targeting - match if context starts with the group name
        # Handle both "group.function" and just "function" in context
        if "." in context_name and "." not in target:
            context_group = context_name.split(".", 1)[0]
            return context_group == target

        if context_name == "<workflow>":
            return target in {"<workflow>", "workflow"}

        # Otherwise the full context name must equal the target exactly
        return context_name == target

    def _find_middle_sentence_index(self, text: str) -> int:
        """Find the index to insert text at the middle sentence boundary.

        Args:
            text: The text to analyze

        Returns:
            The end index of the sentence boundary closest to the middle of ``text``,
            or the middle character index when no boundary exists.
        """
        # Match sentence-ending punctuation followed by whitespace or end of string
        sentence_pattern = r"[.!?](?:\s+|$)"
        matches = list(re.finditer(sentence_pattern, text))

        if not matches:
            # No sentence boundaries found, insert at middle character
            return len(text) // 2

        # Find the sentence boundary closest to the middle
        text_midpoint = len(text) // 2
        closest_match = min(matches, key=lambda m: abs(m.end() - text_midpoint))

        return closest_match.end()

    def _apply_payload_to_simple_type(self,
                                      original_value: list | str | int | float,
                                      attack_payload: str,
                                      payload_placement: str) -> Any:
        """Apply the attack payload to a simple (str, int, float) value.

        Strings support every placement mode. Numbers only support ``replace``; any
        other placement logs a warning and falls back to replacement.

        Args:
            original_value: The original value to attack
            attack_payload: The payload to inject
            payload_placement: How to apply the payload

        Returns:
            The modified value with attack applied

        Raises:
            ValueError: If the placement is unknown, the payload cannot be converted
                to the numeric type, or the value is not a str/int/float (for example
                a nested dict/list or ``None`` selected by ``target_field``).
        """
        value_type = type(original_value)

        # Handle string attacks
        if isinstance(original_value, str):
            original_str = str(original_value)

            if payload_placement == "replace":
                return attack_payload
            if payload_placement == "append_start":
                return f"{attack_payload}{original_str}"
            if payload_placement == "append_end":
                return f"{original_str}{attack_payload}"
            if payload_placement == "append_middle":
                insert_index = self._find_middle_sentence_index(original_str)
                return f"{original_str[:insert_index]}{attack_payload}{original_str[insert_index:]}"
            raise ValueError(f"Unknown payload placement: {payload_placement}")

        # Handle int/float attacks (bool is an int subclass and takes this path too)
        if isinstance(original_value, int | float):
            # For numbers, only replace is allowed
            if payload_placement != "replace":
                logger.warning(
                    "Payload placement '%s' not supported for numeric types (int/float). "
                    "Falling back to 'replace' mode for field with value %s",
                    payload_placement,
                    original_value,
                )

            # Convert payload to the appropriate numeric type
            try:
                if isinstance(original_value, int):
                    return int(attack_payload)
                return float(attack_payload)
            except (ValueError, TypeError) as e:
                raise ValueError(f"Cannot convert attack payload '{attack_payload}' to {value_type.__name__}") from e

        # Anything else (nested dict/list, None, ...) cannot be attacked. Fail loudly
        # instead of returning None, which would overwrite the targeted field.
        raise ValueError(f"Cannot apply attack payload to value of unsupported type {value_type.__name__}; "
                         "only str, int and float values can be attacked")

    def _resolve_multiple_field_matches(self, matches):
        """Pick which of several JSONPath matches to attack.

        The choice follows ``target_field_resolution_strategy``: "random", "first" and
        "last" keep one match, "all" keeps every match, and "error" rejects the input.

        Raises:
            ValueError: For the "error" strategy or an unknown strategy.
        """
        if self._target_field_resolution_strategy == "error":
            raise ValueError(f"Multiple matches found for target_field: {self._target_field}")
        elif self._target_field_resolution_strategy == "random":
            return [random.choice(matches)]
        elif self._target_field_resolution_strategy == "first":
            return [matches[0]]
        elif self._target_field_resolution_strategy == "last":
            return [matches[-1]]
        elif self._target_field_resolution_strategy == "all":
            return matches
        else:
            raise ValueError(f"Unknown target_field_resolution_strategy: {self._target_field_resolution_strategy}")

    def _apply_payload_to_complex_type(self, value: list | dict | BaseModel) -> list | dict | BaseModel:
        """Apply the payload to the field(s) of a container selected by ``target_field``.

        The input is never mutated. A BaseModel is dumped to a dict and rebuilt as the
        same model type; a dict/list is deep-copied before the JSONPath update.

        Args:
            value: The list, dict or Pydantic model to attack.

        Returns:
            A new object of the same kind with the payload applied to the matched field(s).

        Raises:
            ValueError: If ``target_field`` is unset, matches nothing, matches several
                fields under the "error" strategy, or selects a non str/int/float value.
        """
        if self._target_field is None:
            if isinstance(value, BaseModel):
                value_details = value.model_dump_json()
            else:
                value_details = ""
            additional_info = ("Additional info: A pydantic BaseModel with fields:" +
                               value_details if value_details else "")
            raise ValueError("Applying an attack payload to complex type, requires a target_field. \n"
                             f"Input value: {value}.: {value_details}. {additional_info} \n"
                             "A target field can be specified in the middleware configuration as a jsonpath.")

        # Convert BaseModel to dict for jsonpath processing
        original_type = type(value)
        is_basemodel = isinstance(value, BaseModel)
        if is_basemodel:
            value_to_modify = value.model_dump()
        else:
            # The JSONPath update below writes into the container in place; copy it so
            # the caller's object (e.g. the original function arguments) is untouched.
            value_to_modify = copy.deepcopy(value)

        jsonpath_expr = parse(self._target_field)
        matches = jsonpath_expr.find(value_to_modify)
        if len(matches) == 0:
            raise ValueError(f"No matches found for target_field: {self._target_field} in value: {value}")
        if len(matches) > 1:
            matches = self._resolve_multiple_field_matches(matches)
        else:
            matches = [matches[0]]

        # Compute every replacement first so a conversion error leaves nothing half-updated.
        modified_values = [
            self._apply_payload_to_simple_type(match.value, self._attack_payload, self._payload_placement)
            for match in matches
        ]
        for match, modified_value in zip(matches, modified_values):
            match.full_path.update(value_to_modify, modified_value)

        # Reconstruct BaseModel if original was BaseModel
        if is_basemodel:
            assert isinstance(value_to_modify, dict)
            return cast(type[BaseModel], original_type)(**value_to_modify)
        return value_to_modify

    def _apply_payload_to_function_value(self, value: Any) -> Any:
        """Apply the payload to a function input/output value, honoring ``call_limit``.

        Once the limit is reached the value is returned unchanged (with a warning).
        Each successful application increments the call count.

        Raises:
            ValueError: If the value's type is unsupported or the payload cannot be applied.
        """
        if self._call_limit is not None and self._call_count >= self._call_limit:
            logger.warning("Call limit reached for red teaming middleware. "
                           "Not applying attack payload to value: %s",
                           value)
            return value
        if isinstance(value, list | dict | BaseModel):
            modified_value = self._apply_payload_to_complex_type(value)
        elif isinstance(value, str | int | float):
            modified_value = self._apply_payload_to_simple_type(value, self._attack_payload, self._payload_placement)
        else:
            raise ValueError(f"Unsupported function input/output type: {type(value).__name__}")
        self._call_count += 1
        return modified_value

    def _apply_payload_to_function_value_with_exception(self, value: Any, context: FunctionMiddlewareContext) -> Any:
        """Same as ``_apply_payload_to_function_value`` but logs failures with the function name before re-raising."""
        try:
            return self._apply_payload_to_function_value(value)
        except Exception as e:
            logger.error("Failed to apply red team attack to function %s: %s", context.name, e, exc_info=True)
            raise

    async def function_middleware_invoke(self,
                                         *args: Any,
                                         call_next: CallNext,
                                         context: FunctionMiddlewareContext,
                                         **kwargs: Any) -> Any:
        """Invoke middleware for single-output functions.

        Args:
            args: Positional arguments passed to the function (first arg is typically the input value).
            call_next: Callable to invoke next middleware/function.
            context: Metadata about the function being wrapped.
            kwargs: Keyword arguments passed to the function.

        Returns:
            The output value (potentially modified if attacking output).

        Raises:
            ValueError: If ``target_location`` is unknown or the payload cannot be applied.
        """
        # Check if we should attack this function; if not, forward the call untouched
        # (no placeholder argument is added when the caller passed no positional args).
        if not self._should_apply_payload(context.name):
            logger.debug("Skipping function %s (not targeted)", context.name)
            return await call_next(*args, **kwargs)

        if self._target_location == "input":
            # Attack the first positional argument before calling the function
            value = args[0] if args else None
            modified_input = self._apply_payload_to_function_value_with_exception(value, context)
            return await call_next(modified_input, *args[1:], **kwargs)

        if self._target_location == "output":
            # Call function first with its original arguments, then attack the output
            output = await call_next(*args, **kwargs)
            return self._apply_payload_to_function_value_with_exception(output, context)

        raise ValueError(f"Unknown target_location: {self._target_location}. "
                         "Attack payloads can only be applied to function input or output.")
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# Public API of this module (also what ``from ... import *`` exposes).
__all__ = [
    "RedTeamingMiddleware",
]
|