nvidia-nat 1.4.0a20251120__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +1 -1
- nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
- nat/agent/auto_memory_wrapper/agent.py +278 -0
- nat/agent/auto_memory_wrapper/register.py +227 -0
- nat/agent/auto_memory_wrapper/state.py +30 -0
- nat/agent/base.py +1 -1
- nat/agent/dual_node.py +1 -1
- nat/agent/prompt_optimizer/prompt.py +1 -1
- nat/agent/prompt_optimizer/register.py +1 -1
- nat/agent/react_agent/agent.py +16 -9
- nat/agent/react_agent/output_parser.py +2 -2
- nat/agent/react_agent/prompt.py +3 -2
- nat/agent/react_agent/register.py +2 -2
- nat/agent/react_agent/register_per_user_agent.py +104 -0
- nat/agent/reasoning_agent/reasoning_agent.py +1 -1
- nat/agent/register.py +3 -1
- nat/agent/responses_api_agent/__init__.py +1 -1
- nat/agent/responses_api_agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +9 -4
- nat/agent/rewoo_agent/prompt.py +1 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +5 -4
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/__init__.py +1 -1
- nat/authentication/api_key/__init__.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
- nat/authentication/api_key/register.py +1 -1
- nat/authentication/credential_validator/__init__.py +1 -1
- nat/authentication/credential_validator/bearer_token_validator.py +1 -1
- nat/authentication/exceptions/__init__.py +1 -1
- nat/authentication/exceptions/api_key_exceptions.py +1 -1
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/http_basic_auth/register.py +1 -1
- nat/authentication/interfaces.py +1 -1
- nat/authentication/oauth2/__init__.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
- nat/authentication/oauth2/register.py +1 -1
- nat/authentication/register.py +1 -1
- nat/builder/builder.py +511 -1
- nat/builder/child_builder.py +385 -0
- nat/builder/component_utils.py +28 -4
- nat/builder/context.py +17 -1
- nat/builder/embedder.py +1 -1
- nat/builder/eval_builder.py +19 -7
- nat/builder/evaluator.py +1 -1
- nat/builder/framework_enum.py +2 -1
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +40 -3
- nat/builder/function_base.py +1 -1
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +1 -1
- nat/builder/llm.py +1 -1
- nat/builder/per_user_workflow_builder.py +843 -0
- nat/builder/retriever.py +1 -1
- nat/builder/sync_builder.py +571 -0
- nat/builder/user_interaction_manager.py +1 -1
- nat/builder/workflow.py +1 -1
- nat/builder/workflow_builder.py +536 -424
- nat/cli/__init__.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/cli_utils/validation.py +32 -1
- nat/cli/commands/configure/channel/add.py +1 -1
- nat/cli/commands/configure/channel/channel.py +1 -1
- nat/cli/commands/configure/channel/remove.py +1 -1
- nat/cli/commands/configure/channel/update.py +1 -1
- nat/cli/commands/configure/configure.py +1 -1
- nat/cli/commands/evaluate.py +87 -13
- nat/cli/commands/finetune.py +132 -0
- nat/cli/commands/info/__init__.py +1 -1
- nat/cli/commands/info/info.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +1 -1
- nat/cli/commands/object_store/__init__.py +1 -1
- nat/cli/commands/object_store/object_store.py +1 -1
- nat/cli/commands/optimize.py +1 -1
- nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
- nat/cli/commands/red_teaming/red_teaming.py +138 -0
- nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
- nat/cli/commands/registry/__init__.py +1 -1
- nat/cli/commands/registry/publish.py +1 -1
- nat/cli/commands/registry/pull.py +1 -1
- nat/cli/commands/registry/registry.py +1 -1
- nat/cli/commands/registry/remove.py +1 -1
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/sizing/__init__.py +1 -1
- nat/cli/commands/sizing/calc.py +1 -1
- nat/cli/commands/sizing/sizing.py +1 -1
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/validate.py +1 -1
- nat/cli/commands/workflow/__init__.py +1 -1
- nat/cli/commands/workflow/workflow.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/entrypoint.py +15 -37
- nat/cli/main.py +2 -2
- nat/cli/plugin_loader.py +69 -0
- nat/cli/register_workflow.py +183 -5
- nat/cli/type_registry.py +169 -3
- nat/control_flow/register.py +1 -1
- nat/control_flow/router_agent/agent.py +1 -1
- nat/control_flow/router_agent/prompt.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +28 -7
- nat/data_models/__init__.py +1 -1
- nat/data_models/agent.py +1 -1
- nat/data_models/api_server.py +38 -3
- nat/data_models/authentication.py +1 -1
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +7 -1
- nat/data_models/component_ref.py +34 -1
- nat/data_models/config.py +62 -1
- nat/data_models/dataset_handler.py +15 -2
- nat/data_models/discovery_metadata.py +1 -1
- nat/data_models/embedder.py +1 -1
- nat/data_models/evaluate.py +6 -1
- nat/data_models/evaluator.py +1 -1
- nat/data_models/finetuning.py +260 -0
- nat/data_models/front_end.py +1 -1
- nat/data_models/function.py +1 -1
- nat/data_models/function_dependencies.py +1 -1
- nat/data_models/gated_field_mixin.py +1 -1
- nat/data_models/interactive.py +1 -1
- nat/data_models/intermediate_step.py +29 -2
- nat/data_models/invocation_node.py +1 -1
- nat/data_models/llm.py +1 -1
- nat/data_models/logging.py +1 -1
- nat/data_models/memory.py +1 -1
- nat/data_models/middleware.py +3 -1
- nat/data_models/object_store.py +1 -1
- nat/data_models/openai_mcp.py +1 -1
- nat/data_models/optimizable.py +1 -1
- nat/data_models/optimizer.py +1 -1
- nat/data_models/profiler.py +1 -1
- nat/data_models/registry_handler.py +1 -1
- nat/data_models/retriever.py +1 -1
- nat/data_models/retry_mixin.py +1 -1
- nat/data_models/runtime_enum.py +1 -1
- nat/data_models/span.py +1 -1
- nat/data_models/step_adaptor.py +1 -1
- nat/data_models/streaming.py +1 -1
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/telemetry_exporter.py +1 -1
- nat/data_models/thinking_mixin.py +1 -1
- nat/data_models/ttc_strategy.py +1 -1
- nat/embedder/azure_openai_embedder.py +1 -1
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +1 -1
- nat/eval/__init__.py +1 -1
- nat/eval/config.py +8 -1
- nat/eval/dataset_handler/dataset_downloader.py +1 -1
- nat/eval/dataset_handler/dataset_filter.py +1 -1
- nat/eval/dataset_handler/dataset_handler.py +4 -2
- nat/eval/evaluate.py +217 -80
- nat/eval/evaluator/__init__.py +1 -1
- nat/eval/evaluator/base_evaluator.py +2 -2
- nat/eval/evaluator/evaluator_model.py +3 -2
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/llm_validator.py +336 -0
- nat/eval/rag_evaluator/evaluate.py +17 -10
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/red_teaming_evaluator/__init__.py +14 -0
- nat/eval/red_teaming_evaluator/data_models.py +66 -0
- nat/eval/red_teaming_evaluator/evaluate.py +327 -0
- nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
- nat/eval/red_teaming_evaluator/register.py +55 -0
- nat/eval/register.py +2 -1
- nat/eval/remote_workflow.py +1 -1
- nat/eval/runners/__init__.py +1 -1
- nat/eval/runners/config.py +1 -1
- nat/eval/runners/multi_eval_runner.py +1 -1
- nat/eval/runners/red_teaming_runner/__init__.py +24 -0
- nat/eval/runners/red_teaming_runner/config.py +282 -0
- nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
- nat/eval/runners/red_teaming_runner/runner.py +867 -0
- nat/eval/runtime_evaluator/__init__.py +1 -1
- nat/eval/runtime_evaluator/evaluate.py +1 -1
- nat/eval/runtime_evaluator/register.py +1 -1
- nat/eval/runtime_event_subscriber.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/swe_bench_evaluator/register.py +1 -1
- nat/eval/trajectory_evaluator/evaluate.py +2 -2
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
- nat/eval/tunable_rag_evaluator/register.py +1 -1
- nat/eval/usage_stats.py +1 -1
- nat/eval/utils/eval_trace_ctx.py +1 -1
- nat/eval/utils/output_uploader.py +1 -1
- nat/eval/utils/tqdm_position_registry.py +1 -1
- nat/eval/utils/weave_eval.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -1
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/editor_config.py +1 -1
- nat/experimental/test_time_compute/models/scoring_config.py +1 -1
- nat/experimental/test_time_compute/models/search_config.py +20 -2
- nat/experimental/test_time_compute/models/selection_config.py +33 -2
- nat/experimental/test_time_compute/models/stage_enums.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +1 -1
- nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
- nat/experimental/test_time_compute/models/ttc_item.py +1 -1
- nat/experimental/test_time_compute/register.py +4 -1
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
- nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
- nat/finetuning/__init__.py +24 -0
- nat/finetuning/finetuning_runtime.py +143 -0
- nat/finetuning/interfaces/__init__.py +24 -0
- nat/finetuning/interfaces/finetuning_runner.py +261 -0
- nat/finetuning/interfaces/trainer_adapter.py +103 -0
- nat/finetuning/interfaces/trajectory_builder.py +115 -0
- nat/finetuning/utils/__init__.py +15 -0
- nat/finetuning/utils/parsers/__init__.py +15 -0
- nat/finetuning/utils/parsers/adk_parser.py +141 -0
- nat/finetuning/utils/parsers/base_parser.py +238 -0
- nat/finetuning/utils/parsers/common.py +91 -0
- nat/finetuning/utils/parsers/langchain_parser.py +267 -0
- nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
- nat/front_ends/__init__.py +1 -1
- nat/front_ends/console/__init__.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +1 -1
- nat/front_ends/console/console_front_end_config.py +4 -1
- nat/front_ends/console/console_front_end_plugin.py +5 -4
- nat/front_ends/console/register.py +1 -1
- nat/front_ends/cron/__init__.py +1 -1
- nat/front_ends/fastapi/__init__.py +1 -1
- nat/front_ends/fastapi/async_job.py +128 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
- nat/front_ends/fastapi/dask_client_mixin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_config.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +195 -60
- nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
- nat/front_ends/fastapi/job_store.py +23 -11
- nat/front_ends/fastapi/main.py +1 -1
- nat/front_ends/fastapi/message_handler.py +27 -4
- nat/front_ends/fastapi/message_validator.py +54 -2
- nat/front_ends/fastapi/register.py +1 -1
- nat/front_ends/fastapi/response_helpers.py +16 -15
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/fastapi/utils.py +1 -1
- nat/front_ends/register.py +1 -2
- nat/front_ends/simple_base/__init__.py +1 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
- nat/llm/aws_bedrock_llm.py +1 -1
- nat/llm/azure_openai_llm.py +10 -1
- nat/llm/dynamo_llm.py +363 -0
- nat/llm/huggingface_llm.py +177 -0
- nat/llm/litellm_llm.py +1 -1
- nat/llm/nim_llm.py +1 -1
- nat/llm/openai_llm.py +1 -1
- nat/llm/register.py +3 -1
- nat/llm/utils/__init__.py +1 -1
- nat/llm/utils/env_config_value.py +1 -1
- nat/llm/utils/error.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/memory/__init__.py +1 -1
- nat/memory/interfaces.py +1 -1
- nat/memory/models.py +1 -1
- nat/meta/pypi.md +1 -1
- nat/middleware/__init__.py +5 -5
- nat/middleware/cache/__init__.py +14 -0
- nat/middleware/{cache_middleware.py → cache/cache_middleware.py} +39 -42
- nat/middleware/cache/cache_middleware_config.py +44 -0
- nat/middleware/cache/register.py +33 -0
- nat/middleware/defense/__init__.py +14 -0
- nat/middleware/defense/defense_middleware.py +362 -0
- nat/middleware/defense/defense_middleware_content_guard.py +455 -0
- nat/middleware/defense/defense_middleware_data_models.py +91 -0
- nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
- nat/middleware/defense/defense_middleware_pii.py +356 -0
- nat/middleware/defense/register.py +82 -0
- nat/middleware/dynamic/__init__.py +14 -0
- nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
- nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
- nat/middleware/dynamic/register.py +34 -0
- nat/middleware/function_middleware.py +236 -52
- nat/middleware/logging/__init__.py +14 -0
- nat/middleware/logging/logging_middleware.py +67 -0
- nat/middleware/logging/logging_middleware_config.py +28 -0
- nat/middleware/logging/register.py +33 -0
- nat/middleware/middleware.py +142 -28
- nat/middleware/red_teaming/__init__.py +14 -0
- nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
- nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
- nat/middleware/red_teaming/register.py +47 -0
- nat/middleware/register.py +7 -20
- nat/middleware/utils/__init__.py +14 -0
- nat/middleware/utils/workflow_inventory.py +155 -0
- nat/object_store/__init__.py +1 -1
- nat/object_store/in_memory_object_store.py +1 -1
- nat/object_store/interfaces.py +1 -1
- nat/object_store/models.py +1 -1
- nat/object_store/register.py +1 -1
- nat/observability/__init__.py +1 -1
- nat/observability/exporter/__init__.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +1 -1
- nat/observability/exporter/raw_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +7 -1
- nat/observability/exporter_manager.py +1 -1
- nat/observability/mixin/__init__.py +1 -1
- nat/observability/mixin/batch_config_mixin.py +1 -1
- nat/observability/mixin/collector_config_mixin.py +1 -1
- nat/observability/mixin/file_mixin.py +1 -1
- nat/observability/mixin/file_mode.py +1 -1
- nat/observability/mixin/redaction_config_mixin.py +1 -1
- nat/observability/mixin/resource_conflict_mixin.py +1 -1
- nat/observability/mixin/serialize_mixin.py +1 -1
- nat/observability/mixin/tagging_config_mixin.py +1 -1
- nat/observability/mixin/type_introspection_mixin.py +1 -1
- nat/observability/processor/__init__.py +1 -1
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/callback_processor.py +1 -1
- nat/observability/processor/falsy_batch_filter_processor.py +1 -1
- nat/observability/processor/intermediate_step_serializer.py +1 -1
- nat/observability/processor/processor.py +1 -1
- nat/observability/processor/processor_factory.py +1 -1
- nat/observability/processor/redaction/__init__.py +1 -1
- nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
- nat/observability/processor/redaction/redaction_processor.py +1 -1
- nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
- nat/observability/processor/span_tagging_processor.py +1 -1
- nat/observability/register.py +1 -1
- nat/observability/utils/__init__.py +1 -1
- nat/observability/utils/dict_utils.py +1 -1
- nat/observability/utils/time_utils.py +1 -1
- nat/profiler/calc/__init__.py +1 -1
- nat/profiler/calc/calc_runner.py +3 -3
- nat/profiler/calc/calculations.py +1 -1
- nat/profiler/calc/data_models.py +1 -1
- nat/profiler/calc/plot.py +30 -3
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/base_callback_class.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +33 -3
- nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/callbacks/token_usage_base_model.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/data_models.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +16 -1
- nat/profiler/decorators/function_tracking.py +1 -1
- nat/profiler/forecasting/config.py +1 -1
- nat/profiler/forecasting/model_trainer.py +1 -1
- nat/profiler/forecasting/models/__init__.py +1 -1
- nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_metrics_model.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +1 -1
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/inference_optimization/llm_metrics.py +1 -1
- nat/profiler/inference_optimization/prompt_caching.py +1 -1
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
- nat/profiler/intermediate_property_adapter.py +1 -1
- nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
- nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
- nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
- nat/profiler/parameter_optimization/parameter_selection.py +1 -1
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
- nat/profiler/parameter_optimization/update_helpers.py +1 -1
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +1 -1
- nat/registry_handlers/local/register_local.py +1 -1
- nat/registry_handlers/metadata_factory.py +1 -1
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +1 -1
- nat/registry_handlers/pypi/register_pypi.py +1 -1
- nat/registry_handlers/register.py +1 -1
- nat/registry_handlers/registry_handler_base.py +1 -1
- nat/registry_handlers/rest/register_rest.py +1 -1
- nat/registry_handlers/rest/rest_handler.py +1 -1
- nat/registry_handlers/schemas/headers.py +1 -1
- nat/registry_handlers/schemas/package.py +1 -1
- nat/registry_handlers/schemas/publish.py +1 -1
- nat/registry_handlers/schemas/pull.py +1 -1
- nat/registry_handlers/schemas/remove.py +1 -1
- nat/registry_handlers/schemas/search.py +1 -1
- nat/registry_handlers/schemas/status.py +1 -1
- nat/retriever/interface.py +1 -1
- nat/retriever/milvus/__init__.py +1 -1
- nat/retriever/milvus/register.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/models.py +1 -1
- nat/retriever/nemo_retriever/__init__.py +1 -1
- nat/retriever/nemo_retriever/register.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +5 -5
- nat/retriever/register.py +1 -1
- nat/runtime/__init__.py +1 -1
- nat/runtime/loader.py +10 -3
- nat/runtime/metrics.py +180 -0
- nat/runtime/runner.py +1 -5
- nat/runtime/session.py +451 -32
- nat/runtime/user_metadata.py +1 -1
- nat/settings/global_settings.py +1 -1
- nat/tool/chat_completion.py +1 -1
- nat/tool/code_execution/README.md +1 -1
- nat/tool/code_execution/code_sandbox.py +1 -1
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
- nat/tool/code_execution/local_sandbox/__init__.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
- nat/tool/code_execution/register.py +1 -1
- nat/tool/code_execution/utils.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +1 -1
- nat/tool/memory_tools/delete_memory_tool.py +1 -1
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +2 -2
- nat/tool/register.py +1 -1
- nat/tool/retriever.py +1 -1
- nat/tool/server_tools.py +1 -1
- nat/utils/__init__.py +8 -5
- nat/utils/callable_utils.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/debugging_utils.py +1 -1
- nat/utils/decorators.py +1 -1
- nat/utils/dump_distro_mapping.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +3 -3
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/model_processing.py +1 -1
- nat/utils/io/supress_logs.py +33 -0
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/log_levels.py +1 -1
- nat/utils/log_utils.py +13 -1
- nat/utils/metadata_utils.py +1 -1
- nat/utils/optional_imports.py +1 -1
- nat/utils/producer_consumer_queue.py +1 -1
- nat/utils/reactive/base/observable_base.py +1 -1
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/base/subject_base.py +1 -1
- nat/utils/reactive/observable.py +1 -1
- nat/utils/reactive/observer.py +1 -1
- nat/utils/reactive/subject.py +1 -1
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/responses_api.py +1 -1
- nat/utils/settings/global_settings.py +1 -1
- nat/utils/string_utils.py +1 -1
- nat/utils/type_converter.py +18 -5
- nat/utils/type_utils.py +1 -1
- nat/utils/url_utils.py +1 -1
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +39 -14
- nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
- nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
- nat/cli/commands/mcp/mcp.py +0 -986
- nat/front_ends/mcp/introspection_token_verifier.py +0 -73
- nat/front_ends/mcp/mcp_front_end_config.py +0 -109
- nat/front_ends/mcp/mcp_front_end_plugin.py +0 -155
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -388
- nat/front_ends/mcp/memory_profiler.py +0 -320
- nat/front_ends/mcp/register.py +0 -27
- nat/front_ends/mcp/tool_converter.py +0 -321
- nvidia_nat-1.4.0a20251120.dist-info/RECORD +0 -488
- nvidia_nat-1.4.0a20251120.dist-info/entry_points.txt +0 -23
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
nat/middleware/middleware.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -24,23 +24,26 @@ from __future__ import annotations
|
|
|
24
24
|
|
|
25
25
|
import dataclasses
|
|
26
26
|
from abc import ABC
|
|
27
|
+
from abc import abstractmethod
|
|
27
28
|
from collections.abc import AsyncIterator
|
|
28
29
|
from collections.abc import Awaitable
|
|
29
30
|
from collections.abc import Callable
|
|
30
31
|
from typing import Any
|
|
31
32
|
|
|
32
33
|
from pydantic import BaseModel
|
|
34
|
+
from pydantic import ConfigDict
|
|
35
|
+
from pydantic import Field
|
|
33
36
|
|
|
34
37
|
#: Type alias for single-output invocation callables.
|
|
35
|
-
CallNext = Callable[
|
|
38
|
+
CallNext = Callable[..., Awaitable[Any]]
|
|
36
39
|
|
|
37
40
|
#: Type alias for streaming invocation callables.
|
|
38
|
-
CallNextStream = Callable[
|
|
41
|
+
CallNextStream = Callable[..., AsyncIterator[Any]]
|
|
39
42
|
|
|
40
43
|
|
|
41
44
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
42
45
|
class FunctionMiddlewareContext:
|
|
43
|
-
"""
|
|
46
|
+
"""Static metadata about the function being wrapped by middleware.
|
|
44
47
|
|
|
45
48
|
Middleware receives this context object which describes the function they
|
|
46
49
|
are wrapping. This allows middleware to make decisions based on the
|
|
@@ -66,31 +69,58 @@ class FunctionMiddlewareContext:
|
|
|
66
69
|
"""Schema describing streaming outputs or :class:`types.NoneType` when absent."""
|
|
67
70
|
|
|
68
71
|
|
|
72
|
+
class InvocationContext(BaseModel):
|
|
73
|
+
"""Unified context for pre-invoke and post-invoke phases.
|
|
74
|
+
|
|
75
|
+
Used for both phases of middleware execution:
|
|
76
|
+
- Pre-invoke: output is None, modify modified_args/modified_kwargs to transform inputs
|
|
77
|
+
- Post-invoke: output contains the function result, modify output to transform results
|
|
78
|
+
|
|
79
|
+
This unified context simplifies the middleware interface by using a single
|
|
80
|
+
context type for both hooks.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
model_config = ConfigDict(validate_assignment=True)
|
|
84
|
+
|
|
85
|
+
# Frozen fields - cannot be modified after creation
|
|
86
|
+
function_context: FunctionMiddlewareContext = Field(
|
|
87
|
+
frozen=True, description="Static metadata about the function being invoked (frozen).")
|
|
88
|
+
original_args: tuple[Any, ...] = Field(
|
|
89
|
+
frozen=True, description="The original function input arguments before any middleware processing.")
|
|
90
|
+
original_kwargs: dict[str, Any] = Field(
|
|
91
|
+
frozen=True, description="The original function input keyword arguments before any middleware processing.")
|
|
92
|
+
|
|
93
|
+
# Mutable fields - modify these to transform inputs/outputs
|
|
94
|
+
modified_args: tuple[Any, ...] = Field(description="Modified args after middleware processing.")
|
|
95
|
+
modified_kwargs: dict[str, Any] = Field(description="Modified kwargs after middleware processing.")
|
|
96
|
+
output: Any = Field(default=None, description="Function output. None pre-invoke, result post-invoke.")
|
|
97
|
+
|
|
98
|
+
|
|
69
99
|
class Middleware(ABC):
|
|
70
|
-
"""Base class for middleware-style wrapping.
|
|
100
|
+
"""Base class for middleware-style wrapping with pre/post-invoke hooks.
|
|
71
101
|
|
|
72
102
|
Middleware works like middleware in web frameworks:
|
|
73
103
|
|
|
74
|
-
1. **Preprocess**: Inspect and optionally modify inputs
|
|
104
|
+
1. **Preprocess**: Inspect and optionally modify inputs (via pre_invoke)
|
|
75
105
|
2. **Call Next**: Delegate to the next middleware or the target itself
|
|
76
|
-
3. **Postprocess**: Process, transform, or augment the output
|
|
106
|
+
3. **Postprocess**: Process, transform, or augment the output (via post_invoke)
|
|
77
107
|
4. **Continue**: Return or yield the final result
|
|
78
108
|
|
|
79
109
|
Example::
|
|
80
110
|
|
|
81
|
-
class LoggingMiddleware(
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
# 2. Call next middleware/target
|
|
87
|
-
result = await call_next(value)
|
|
111
|
+
class LoggingMiddleware(FunctionMiddleware):
|
|
112
|
+
@property
|
|
113
|
+
def enabled(self) -> bool:
|
|
114
|
+
return True
|
|
88
115
|
|
|
89
|
-
|
|
90
|
-
print(f"
|
|
116
|
+
async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
|
|
117
|
+
print(f"Current args: {context.modified_args}")
|
|
118
|
+
print(f"Original args: {context.original_args}")
|
|
119
|
+
return None # Pass through unchanged
|
|
91
120
|
|
|
92
|
-
|
|
93
|
-
|
|
121
|
+
async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
|
|
122
|
+
print(f"Output: {context.output}")
|
|
123
|
+
return None # Pass through unchanged
|
|
94
124
|
|
|
95
125
|
Attributes:
|
|
96
126
|
is_final: If True, this middleware terminates the chain. No subsequent
|
|
@@ -101,6 +131,78 @@ class Middleware(ABC):
|
|
|
101
131
|
def __init__(self, *, is_final: bool = False) -> None:
|
|
102
132
|
self._is_final = is_final
|
|
103
133
|
|
|
134
|
+
# ==================== Abstract Members ====================
|
|
135
|
+
|
|
136
|
+
@property
|
|
137
|
+
@abstractmethod
|
|
138
|
+
def enabled(self) -> bool:
|
|
139
|
+
"""Whether this middleware should execute.
|
|
140
|
+
"""
|
|
141
|
+
...
|
|
142
|
+
|
|
143
|
+
@abstractmethod
|
|
144
|
+
async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
|
|
145
|
+
"""Transform inputs before execution.
|
|
146
|
+
|
|
147
|
+
Called by specialized middleware invoke methods (e.g., function_middleware_invoke).
|
|
148
|
+
Use to validate, transform, or augment inputs. At this phase, context.output is None.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
context: Invocation context (Pydantic model) containing:
|
|
152
|
+
- function_context: Static function metadata (frozen)
|
|
153
|
+
- original_args: What entered the middleware chain (frozen)
|
|
154
|
+
- original_kwargs: What entered the middleware chain (frozen)
|
|
155
|
+
- modified_args: Current args (mutable)
|
|
156
|
+
- modified_kwargs: Current kwargs (mutable)
|
|
157
|
+
- output: None (function not yet called)
|
|
158
|
+
|
|
159
|
+
Returns:
|
|
160
|
+
InvocationContext: Return the (modified) context to signal changes
|
|
161
|
+
None: Pass through unchanged (framework uses current context state)
|
|
162
|
+
|
|
163
|
+
Note:
|
|
164
|
+
Frozen fields (original_args, original_kwargs) cannot be modified.
|
|
165
|
+
Attempting to modify them raises ValidationError.
|
|
166
|
+
|
|
167
|
+
Raises:
|
|
168
|
+
Any exception to abort execution
|
|
169
|
+
"""
|
|
170
|
+
...
|
|
171
|
+
|
|
172
|
+
@abstractmethod
async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
    """Hook that runs after the wrapped target has produced output.

    Specialized invoke methods (e.g., function_middleware_invoke) call this hook so
    that outputs can be validated, transformed, or augmented. For streaming
    invocations it is called once per chunk.

    Args:
        context: Invocation context (Pydantic model) with these fields:

            - function_context: Static function metadata (frozen)
            - original_args: Args as they entered the middleware chain (frozen)
            - original_kwargs: Kwargs as they entered the middleware chain (frozen)
            - modified_args: Args the function received (mutable)
            - modified_kwargs: Kwargs the function received (mutable)
            - output: Current output value (mutable)

    Returns:
        The (modified) InvocationContext to signal changes, or None to pass
        through unchanged (the framework then uses the current ``context.output``).

    Example::

        async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
            # Wrap the output
            context.output = {"result": context.output, "processed": True}
            return context  # Signal modification

    Raises:
        Any exception, to abort and propagate the error.
    """
    ...
|
|
203
|
+
|
|
204
|
+
# ==================== Properties ====================
|
|
205
|
+
|
|
104
206
|
@property
|
|
105
207
|
def is_final(self) -> bool:
|
|
106
208
|
"""Whether this middleware terminates the chain.
|
|
@@ -111,13 +213,20 @@ class Middleware(ABC):
|
|
|
111
213
|
|
|
112
214
|
return self._is_final
|
|
113
215
|
|
|
114
|
-
|
|
216
|
+
# ==================== Default Invoke Methods ====================
|
|
217
|
+
|
|
218
|
+
async def middleware_invoke(self,
|
|
219
|
+
value: Any,
|
|
220
|
+
call_next: CallNext,
|
|
221
|
+
context: FunctionMiddlewareContext,
|
|
222
|
+
**kwargs: Any) -> Any:
|
|
115
223
|
"""Middleware for single-output invocations.
|
|
116
224
|
|
|
117
225
|
Args:
|
|
118
226
|
value: The input value to process
|
|
119
227
|
call_next: Callable to invoke the next middleware or target
|
|
120
228
|
context: Metadata about the target being wrapped
|
|
229
|
+
kwargs: Additional function arguments
|
|
121
230
|
|
|
122
231
|
Returns:
|
|
123
232
|
The (potentially modified) output from the target
|
|
@@ -125,12 +234,12 @@ class Middleware(ABC):
|
|
|
125
234
|
The default implementation simply delegates to ``call_next``. Override this
|
|
126
235
|
to add preprocessing, postprocessing, or to short-circuit execution::
|
|
127
236
|
|
|
128
|
-
async def middleware_invoke(self, value, call_next, context):
|
|
237
|
+
async def middleware_invoke(self, value, call_next, context, **kwargs):
|
|
129
238
|
# Preprocess: modify input
|
|
130
239
|
modified_input = transform(value)
|
|
131
240
|
|
|
132
241
|
# Call next: delegate to next middleware/target
|
|
133
|
-
result = await call_next(modified_input)
|
|
242
|
+
result = await call_next(modified_input, **kwargs)
|
|
134
243
|
|
|
135
244
|
# Postprocess: modify output
|
|
136
245
|
modified_result = transform_output(result)
|
|
@@ -140,16 +249,20 @@ class Middleware(ABC):
|
|
|
140
249
|
"""
|
|
141
250
|
|
|
142
251
|
del context # Unused by the default implementation.
|
|
143
|
-
return await call_next(value)
|
|
252
|
+
return await call_next(value, **kwargs)
|
|
144
253
|
|
|
145
|
-
async def middleware_stream(self,
|
|
146
|
-
|
|
254
|
+
async def middleware_stream(self,
|
|
255
|
+
value: Any,
|
|
256
|
+
call_next: CallNextStream,
|
|
257
|
+
context: FunctionMiddlewareContext,
|
|
258
|
+
**kwargs: Any) -> AsyncIterator[Any]:
|
|
147
259
|
"""Middleware for streaming invocations.
|
|
148
260
|
|
|
149
261
|
Args:
|
|
150
262
|
value: The input value to process
|
|
151
263
|
call_next: Callable to invoke the next middleware or target stream
|
|
152
264
|
context: Metadata about the target being wrapped
|
|
265
|
+
kwargs: Additional function arguments
|
|
153
266
|
|
|
154
267
|
Yields:
|
|
155
268
|
Chunks from the stream (potentially modified)
|
|
@@ -157,12 +270,12 @@ class Middleware(ABC):
|
|
|
157
270
|
The default implementation forwards to ``call_next`` untouched. Override this
|
|
158
271
|
to add preprocessing, transform chunks, or perform cleanup::
|
|
159
272
|
|
|
160
|
-
async def middleware_stream(self, value, call_next, context):
|
|
273
|
+
async def middleware_stream(self, value, call_next, context, **kwargs):
|
|
161
274
|
# Preprocess: setup or modify input
|
|
162
275
|
modified_input = transform(value)
|
|
163
276
|
|
|
164
277
|
# Call next: get stream from next middleware/target
|
|
165
|
-
async for chunk in call_next(modified_input):
|
|
278
|
+
async for chunk in call_next(modified_input, **kwargs):
|
|
166
279
|
# Process each chunk
|
|
167
280
|
modified_chunk = transform_chunk(chunk)
|
|
168
281
|
yield modified_chunk
|
|
@@ -172,13 +285,14 @@ class Middleware(ABC):
|
|
|
172
285
|
"""
|
|
173
286
|
|
|
174
287
|
del context # Unused by the default implementation.
|
|
175
|
-
async for chunk in call_next(value):
|
|
288
|
+
async for chunk in call_next(value, **kwargs):
|
|
176
289
|
yield chunk
|
|
177
290
|
|
|
178
291
|
|
|
179
292
|
__all__ = [
|
|
180
293
|
"CallNext",
|
|
181
294
|
"CallNextStream",
|
|
182
|
-
"Middleware",
|
|
183
295
|
"FunctionMiddlewareContext",
|
|
296
|
+
"InvocationContext",
|
|
297
|
+
"Middleware",
|
|
184
298
|
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Red teaming middleware for attacking agent functions.
|
|
16
|
+
|
|
17
|
+
This module provides a middleware for red teaming and security testing that can
|
|
18
|
+
intercept and modify function inputs or outputs with configurable attack payloads.
|
|
19
|
+
|
|
20
|
+
The middleware supports:
|
|
21
|
+
- Targeting specific functions or entire function groups
|
|
22
|
+
- Field-level search within input/output schemas
|
|
23
|
+
- Multiple attack modes (replace, append_start, append_middle, append_end)
|
|
24
|
+
- Both regular and streaming function calls
|
|
25
|
+
- Type-safe operations on strings, integers, and floats
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import logging
|
|
31
|
+
import random
|
|
32
|
+
import re
|
|
33
|
+
from typing import Any
|
|
34
|
+
from typing import Literal
|
|
35
|
+
from typing import cast
|
|
36
|
+
|
|
37
|
+
from jsonpath_ng import parse
|
|
38
|
+
from pydantic import BaseModel
|
|
39
|
+
|
|
40
|
+
from nat.middleware.function_middleware import CallNext
|
|
41
|
+
from nat.middleware.function_middleware import FunctionMiddleware
|
|
42
|
+
from nat.middleware.function_middleware import FunctionMiddlewareContext
|
|
43
|
+
|
|
44
|
+
logger = logging.getLogger(__name__)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class RedTeamingMiddleware(FunctionMiddleware):
|
|
48
|
+
"""Middleware for red teaming that intercepts and modifies function inputs/outputs.
|
|
49
|
+
|
|
50
|
+
This middleware enables systematic security testing by injecting attack payloads
|
|
51
|
+
into function inputs or outputs. It supports flexible targeting, field-level
|
|
52
|
+
modifications, and multiple attack modes.
|
|
53
|
+
|
|
54
|
+
Features:
|
|
55
|
+
|
|
56
|
+
* Target specific functions or entire function groups
|
|
57
|
+
* Search for specific fields in input/output schemas
|
|
58
|
+
* Apply attacks via replace or append modes
|
|
59
|
+
* Support for both regular and streaming calls
|
|
60
|
+
* Type-safe operations on strings, numbers
|
|
61
|
+
|
|
62
|
+
Example::
|
|
63
|
+
|
|
64
|
+
# In YAML config
|
|
65
|
+
middleware:
|
|
66
|
+
prompt_injection:
|
|
67
|
+
_type: red_teaming
|
|
68
|
+
attack_payload: "Ignore previous instructions"
|
|
69
|
+
target_function_or_group: my_llm.generate
|
|
70
|
+
payload_placement: append_start
|
|
71
|
+
target_location: input
|
|
72
|
+
target_field: prompt
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
attack_payload: The malicious payload to inject.
|
|
76
|
+
target_function_or_group: Function or group to target (None for all).
|
|
77
|
+
payload_placement: How to apply (replace, append_start, append_middle, append_end).
|
|
78
|
+
target_location: Whether to attack input or output.
|
|
79
|
+
target_field: Field name or path to attack (None for direct value).
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
def __init__(
    self,
    *,
    attack_payload: str,
    target_function_or_group: str | None = None,
    payload_placement: Literal["replace", "append_start", "append_middle", "append_end"] = "append_end",
    target_location: Literal["input", "output"] = "input",
    target_field: str | None = None,
    target_field_resolution_strategy: Literal["random", "first", "last", "all", "error"] = "error",
    call_limit: int | None = None,
) -> None:
    """Configure the red teaming middleware.

    All arguments are keyword-only and are stored on the instance unchanged.

    Args:
        attack_payload: Value injected into the function input or output.
        target_function_or_group: Function or group to attack; None targets every function.
        payload_placement: How the payload is applied (replace or one of the append modes).
        target_location: Whether the payload is placed in the input or the output.
        target_field: JSONPath of the field to attack.
        target_field_resolution_strategy: What to do when the JSONPath matches
            several fields (random/first/last/all/error).
        call_limit: Maximum number of times a payload is applied; None means no limit.
    """
    # This middleware never terminates the chain.
    super().__init__(is_final=False)

    # What to inject and how.
    self._attack_payload = attack_payload
    self._payload_placement = payload_placement

    # Where to inject it.
    self._target_function_or_group = target_function_or_group
    self._target_location = target_location
    self._target_field = target_field
    self._target_field_resolution_strategy = target_field_resolution_strategy

    # Bookkeeping: number of payloads applied so far, and the optional cap on it.
    self._call_limit = call_limit
    self._call_count: int = 0

    logger.info(
        "RedTeamingMiddleware initialized: payload=%s, target=%s, placement=%s, location=%s, field=%s",
        attack_payload,
        target_function_or_group,
        payload_placement,
        target_location,
        target_field,
    )
|
|
121
|
+
|
|
122
|
+
def _should_apply_payload(self, context_name: str) -> bool:
    """Decide whether the function named ``context_name`` is a configured target.

    Args:
        context_name: Function name taken from the context (e.g., "calculator.add").

    Returns:
        True if the payload should be applied to this function, False otherwise.
    """
    wanted = self._target_function_or_group

    # No target configured: every function is attacked.
    if wanted is None:
        return True

    # Group targeting: a dotted context name is compared by its group prefix
    # whenever the configured target itself has no dot.
    if "." in context_name and "." not in wanted:
        group_name, _, _ = context_name.partition(".")
        return group_name == wanted

    # The workflow entry point may be targeted with or without angle brackets.
    if context_name == "<workflow>":
        return wanted in ("<workflow>", "workflow")

    # Everything else requires an exact name match.
    return wanted == context_name
|
|
148
|
+
|
|
149
|
+
def _find_middle_sentence_index(self, text: str) -> int:
    """Locate the sentence boundary nearest to the middle of ``text``.

    Args:
        text: The text to analyze.

    Returns:
        The character index just past the sentence boundary closest to the
        midpoint, or the midpoint itself when the text has no sentence boundary.
    """
    midpoint = len(text) // 2

    # A boundary is sentence-ending punctuation followed by whitespace or end of string;
    # the recorded index is the position right after that whole match.
    boundary_ends = [found.end() for found in re.finditer(r"[.!?](?:\s+|$)", text)]
    if not boundary_ends:
        return midpoint

    # On ties, min() keeps the earliest boundary.
    return min(boundary_ends, key=lambda end: abs(end - midpoint))
|
|
172
|
+
|
|
173
|
+
def _apply_payload_to_simple_type(self,
|
|
174
|
+
original_value: list | str | int | float,
|
|
175
|
+
attack_payload: str,
|
|
176
|
+
payload_placement: str) -> Any:
|
|
177
|
+
"""Apply the attack payload to simple types (str, int, float) value.
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
original_value: The original value to attack
|
|
181
|
+
attack_payload: The payload to inject
|
|
182
|
+
payload_placement: How to apply the payload
|
|
183
|
+
|
|
184
|
+
Returns:
|
|
185
|
+
The modified value with attack applied
|
|
186
|
+
|
|
187
|
+
Raises:
|
|
188
|
+
ValueError: If attack cannot be applied due to type mismatch
|
|
189
|
+
"""
|
|
190
|
+
# Determine actual type from value if not provided
|
|
191
|
+
value_type = type(original_value)
|
|
192
|
+
|
|
193
|
+
# Handle string attacks
|
|
194
|
+
if value_type is str or isinstance(original_value, str):
|
|
195
|
+
original_str = str(original_value)
|
|
196
|
+
|
|
197
|
+
if payload_placement == "replace":
|
|
198
|
+
return attack_payload
|
|
199
|
+
elif payload_placement == "append_start":
|
|
200
|
+
return f"{attack_payload}{original_str}"
|
|
201
|
+
elif payload_placement == "append_end":
|
|
202
|
+
return f"{original_str}{attack_payload}"
|
|
203
|
+
elif payload_placement == "append_middle":
|
|
204
|
+
insert_index = self._find_middle_sentence_index(original_str)
|
|
205
|
+
return f"{original_str[:insert_index]}{attack_payload}{original_str[insert_index:]}"
|
|
206
|
+
else:
|
|
207
|
+
raise ValueError(f"Unknown payload placement: {payload_placement}")
|
|
208
|
+
|
|
209
|
+
# Handle int/float attacks
|
|
210
|
+
if isinstance(original_value, int | float):
|
|
211
|
+
# For numbers, only replace is allowed
|
|
212
|
+
if payload_placement != "replace":
|
|
213
|
+
logger.warning(
|
|
214
|
+
"Payload placement '%s' not supported for numeric types (int/float). "
|
|
215
|
+
"Falling back to 'replace' mode for field with value %s",
|
|
216
|
+
payload_placement,
|
|
217
|
+
original_value,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# Convert payload to the appropriate numeric type
|
|
221
|
+
try:
|
|
222
|
+
if value_type is int or isinstance(original_value, int):
|
|
223
|
+
return int(attack_payload)
|
|
224
|
+
return float(attack_payload)
|
|
225
|
+
except (ValueError, TypeError) as e:
|
|
226
|
+
raise ValueError(f"Cannot convert attack payload '{attack_payload}' to {value_type.__name__}") from e
|
|
227
|
+
|
|
228
|
+
def _resolve_multiple_field_matches(self, matches):
    """Narrow several JSONPath matches down according to the configured strategy.

    Args:
        matches: The matches found for the target field.

    Returns:
        A one-element list for the "random", "first" and "last" strategies, or
        ``matches`` itself for "all".

    Raises:
        ValueError: If the strategy is "error" or is not a known strategy.
    """
    strategy = self._target_field_resolution_strategy

    if strategy == "error":
        raise ValueError(f"Multiple matches found for target_field: {self._target_field}")
    if strategy == "all":
        return matches
    if strategy == "first":
        return [matches[0]]
    if strategy == "last":
        return [matches[-1]]
    if strategy == "random":
        return [random.choice(matches)]

    raise ValueError(f"Unknown target_field_resolution_strategy: {self._target_field_resolution_strategy}")
|
|
241
|
+
|
|
242
|
+
def _apply_payload_to_complex_type(self, value: list | dict | BaseModel) -> list | dict | BaseModel:
    """Apply the attack payload to the field(s) of a container selected by ``target_field``.

    The JSONPath in ``target_field`` is evaluated against the value (a BaseModel is
    first converted with ``model_dump()``). Each selected match is rewritten with the
    configured payload and placement. The input object itself is left untouched.

    Args:
        value: The list, dict or pydantic model to attack.

    Returns:
        A modified copy of ``value``; a BaseModel input is rebuilt as the same model class.

    Raises:
        ValueError: If no ``target_field`` is configured, if the JSONPath matches
            nothing, or if resolving multiple matches or applying the payload
            to a matched value fails.
    """
    # Local import keeps this block self-contained without touching the file's import section.
    import copy

    if self._target_field is None:
        if isinstance(value, BaseModel):
            additional_info = f"Additional info: A pydantic BaseModel with fields: {value.model_dump_json()}. "
        else:
            additional_info = ""
        raise ValueError("Applying an attack payload to a complex type requires a target_field.\n"
                         f"Input value: {value}. {additional_info}\n"
                         "A target field can be specified in the middleware configuration as a jsonpath.")

    is_basemodel = isinstance(value, BaseModel)
    if is_basemodel:
        # jsonpath works on plain containers, so models are edited via their dumped dict.
        value_to_modify = value.model_dump()
    else:
        # jsonpath updates write into the container they are given. Work on a deep copy
        # so the caller's list/dict is never modified behind its back (otherwise the
        # payload would persist in, or accumulate on, objects the caller reuses).
        value_to_modify = copy.deepcopy(value)

    matches = parse(self._target_field).find(value_to_modify)
    if not matches:
        raise ValueError(f"No matches found for target_field: {self._target_field} in value: {value}")
    if len(matches) > 1:
        matches = self._resolve_multiple_field_matches(matches)

    # Compute every replacement before writing any of them, so a failure on one
    # match leaves the working copy without partial edits.
    modified_values = [
        self._apply_payload_to_simple_type(match.value, self._attack_payload, self._payload_placement)
        for match in matches
    ]
    for match, modified_value in zip(matches, modified_values):
        match.full_path.update(value_to_modify, modified_value)

    if is_basemodel:
        # Rebuild the same model class from the edited dict.
        model_cls = cast(type[BaseModel], type(value))
        return model_cls(**value_to_modify)
    return value_to_modify
|
|
282
|
+
|
|
283
|
+
def _apply_payload_to_function_value(self, value: Any) -> Any:
    """Apply the attack payload to a function input or output, honoring the call limit.

    Args:
        value: The function input or output to attack.

    Returns:
        The attacked value, or ``value`` unchanged once the call limit has been reached.

    Raises:
        ValueError: If ``value`` is not a list, dict, BaseModel, str, int or float.
    """
    limit = self._call_limit
    if limit is not None and self._call_count >= limit:
        logger.warning("Call limit reached for red teaming middleware. "
                       "Not applying attack payload to value: %s",
                       value)
        return value

    if isinstance(value, list | dict | BaseModel):
        attacked = self._apply_payload_to_complex_type(value)
    elif isinstance(value, str | int | float):
        attacked = self._apply_payload_to_simple_type(value, self._attack_payload, self._payload_placement)
    else:
        raise ValueError(f"Unsupported function input/output type: {type(value).__name__}")

    # Only successful applications count toward the limit.
    self._call_count += 1
    return attacked
|
|
297
|
+
|
|
298
|
+
def _apply_payload_to_function_value_with_exception(self, value: Any, context: FunctionMiddlewareContext) -> Any:
    """Apply the payload to ``value``, logging any failure before re-raising it.

    Args:
        value: The function input or output to attack.
        context: Metadata about the wrapped function; its ``name`` appears in the error log.

    Returns:
        The result of ``_apply_payload_to_function_value``.

    Raises:
        Exception: Whatever the underlying application raised, unchanged.
    """
    try:
        attacked = self._apply_payload_to_function_value(value)
    except Exception as e:
        logger.error("Failed to apply red team attack to function %s: %s", context.name, e, exc_info=True)
        raise
    return attacked
|
|
304
|
+
|
|
305
|
+
async def function_middleware_invoke(self,
                                     *args: Any,
                                     call_next: CallNext,
                                     context: FunctionMiddlewareContext,
                                     **kwargs: Any) -> Any:
    """Invoke middleware for single-output functions.

    Functions that are not targeted are called through untouched. For targeted
    functions the payload is applied either to the first positional argument
    (``target_location == "input"``) or to the returned value
    (``target_location == "output"``).

    Args:
        args: Positional arguments passed to the function (first arg is typically the input value).
        call_next: Callable to invoke next middleware/function.
        context: Metadata about the function being wrapped.
        kwargs: Keyword arguments passed to the function.

    Returns:
        The output value (potentially modified if attacking output).

    Raises:
        ValueError: If ``target_location`` is neither "input" nor "output".
        Exception: Any error raised while applying the payload is re-raised.
    """
    # Check if we should attack this function.
    if not self._should_apply_payload(context.name):
        logger.debug("Skipping function %s (not targeted)", context.name)
        # Forward the call exactly as received. The previous code injected a
        # positional None when the function was called with keyword arguments only.
        return await call_next(*args, **kwargs)

    if self._target_location == "input":
        # Attack the first positional argument before calling the function.
        value = args[0] if args else None
        modified_input = self._apply_payload_to_function_value_with_exception(value, context)
        if not args:
            # No positional input existed, so do not invent one for the callee.
            return await call_next(**kwargs)
        return await call_next(modified_input, *args[1:], **kwargs)

    if self._target_location == "output":
        # Call the function first with its original arguments, then attack the output.
        output = await call_next(*args, **kwargs)
        return self._apply_payload_to_function_value_with_exception(output, context)

    raise ValueError(f"Unknown target_location: {self._target_location}. "
                     "Attack payloads can only be applied to function input or output.")
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
__all__ = ["RedTeamingMiddleware"]
|