aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/base.py +170 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +146 -112
- aiq/agent/react_agent/prompt.py +1 -6
- aiq/agent/react_agent/register.py +36 -35
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/agent/tool_calling_agent/register.py +1 -1
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +37 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +44 -12
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +421 -61
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +2 -1
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +124 -12
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/dataset_handler.py +2 -1
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/evaluate.py +23 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +38 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +1 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +87 -2
- aiq/eval/evaluate.py +208 -27
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +1 -0
- aiq/eval/intermediate_step_adapter.py +11 -5
- aiq/eval/rag_evaluator/evaluate.py +55 -15
- aiq/eval/rag_evaluator/register.py +6 -1
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/trajectory_evaluator/evaluate.py +22 -65
- aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
- aiq/eval/tunable_rag_evaluator/register.py +2 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/output_uploader.py +10 -1
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/job_store.py +47 -25
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +108 -89
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +3 -2
- aiq/llm/register.py +1 -0
- aiq/meta/pypi.md +12 -12
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +36 -39
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/inference_metrics_model.py +3 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
- aiq/profiler/inference_optimization/data_models.py +2 -2
- aiq/profiler/inference_optimization/llm_metrics.py +2 -2
- aiq/profiler/profile_runner.py +61 -21
- aiq/runtime/loader.py +9 -3
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +25 -7
- aiq/runtime/user_metadata.py +2 -3
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +41 -6
- aiq/tool/mcp/mcp_tool.py +3 -2
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +6 -3
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/METADATA +53 -21
- aiqtoolkit-1.2.0rc2.dist-info/RECORD +436 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/WHEEL +1 -1
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -148
- aiq/observability/async_otel_listener.py +0 -429
- aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
17
|
|
|
18
|
+
from pydantic import AliasChoices
|
|
18
19
|
from pydantic import Field
|
|
19
20
|
|
|
20
21
|
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
@@ -42,11 +43,24 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
42
43
|
description="The list of tools to provide to the react agent.")
|
|
43
44
|
llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
|
|
44
45
|
verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
|
|
45
|
-
|
|
46
|
-
|
|
46
|
+
retry_agent_response_parsing_errors: bool = Field(
|
|
47
|
+
default=True,
|
|
48
|
+
validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
|
|
49
|
+
description="Whether to retry when encountering parsing errors in the agent's response.")
|
|
50
|
+
parse_agent_response_max_retries: int = Field(
|
|
51
|
+
default=1,
|
|
52
|
+
validation_alias=AliasChoices("parse_agent_response_max_retries", "max_retries"),
|
|
53
|
+
description="Maximum number of times the Agent may retry parsing errors. "
|
|
54
|
+
"Prevents the Agent from getting into infinite hallucination loops.")
|
|
55
|
+
tool_call_max_retries: int = Field(default=1, description="The number of retries before raising a tool call error.")
|
|
56
|
+
max_tool_calls: int = Field(default=15,
|
|
57
|
+
validation_alias=AliasChoices("max_tool_calls", "max_iterations"),
|
|
58
|
+
description="Maximum number of tool calls before stopping the agent.")
|
|
59
|
+
pass_tool_call_errors_to_agent: bool = Field(
|
|
60
|
+
default=True,
|
|
61
|
+
description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
|
|
47
62
|
include_tool_input_schema_in_tool_description: bool = Field(
|
|
48
63
|
default=True, description="Specify inclusion of tool input schemas in the prompt.")
|
|
49
|
-
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the react agent.")
|
|
50
64
|
description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
|
|
51
65
|
system_prompt: str | None = Field(
|
|
52
66
|
default=None,
|
|
@@ -63,29 +77,13 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
63
77
|
async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
|
|
64
78
|
from langchain.schema import BaseMessage
|
|
65
79
|
from langchain_core.messages import trim_messages
|
|
66
|
-
from langchain_core.prompts import ChatPromptTemplate
|
|
67
|
-
from langchain_core.prompts import MessagesPlaceholder
|
|
68
80
|
from langgraph.graph.graph import CompiledGraph
|
|
69
81
|
|
|
70
|
-
from aiq.agent.react_agent.
|
|
71
|
-
|
|
72
|
-
from .agent import
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
# the ReAct Agent prompt comes from prompt.py, and can be customized there or via config option system_prompt.
|
|
77
|
-
if config.system_prompt:
|
|
78
|
-
_prompt_str = config.system_prompt
|
|
79
|
-
if config.additional_instructions:
|
|
80
|
-
_prompt_str += f" {config.additional_instructions}"
|
|
81
|
-
valid_prompt = ReActAgentGraph.validate_system_prompt(config.system_prompt)
|
|
82
|
-
if not valid_prompt:
|
|
83
|
-
logger.exception("%s Invalid system_prompt", AGENT_LOG_PREFIX)
|
|
84
|
-
raise ValueError("Invalid system_prompt")
|
|
85
|
-
prompt = ChatPromptTemplate([("system", config.system_prompt), ("user", USER_PROMPT),
|
|
86
|
-
MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
|
|
87
|
-
else:
|
|
88
|
-
prompt = react_agent_prompt
|
|
82
|
+
from aiq.agent.react_agent.agent import ReActAgentGraph
|
|
83
|
+
from aiq.agent.react_agent.agent import ReActGraphState
|
|
84
|
+
from aiq.agent.react_agent.agent import create_react_agent_prompt
|
|
85
|
+
|
|
86
|
+
prompt = create_react_agent_prompt(config)
|
|
89
87
|
|
|
90
88
|
# we can choose an LLM for the ReAct agent in the config file
|
|
91
89
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
@@ -96,13 +94,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
96
94
|
raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
|
|
97
95
|
# configure callbacks, for sending intermediate steps
|
|
98
96
|
# construct the ReAct Agent Graph from the configured llm, prompt, and tools
|
|
99
|
-
graph: CompiledGraph = await ReActAgentGraph(
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
97
|
+
graph: CompiledGraph = await ReActAgentGraph(
|
|
98
|
+
llm=llm,
|
|
99
|
+
prompt=prompt,
|
|
100
|
+
tools=tools,
|
|
101
|
+
use_tool_schema=config.include_tool_input_schema_in_tool_description,
|
|
102
|
+
detailed_logs=config.verbose,
|
|
103
|
+
retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
|
|
104
|
+
parse_agent_response_max_retries=config.parse_agent_response_max_retries,
|
|
105
|
+
tool_call_max_retries=config.tool_call_max_retries,
|
|
106
|
+
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent).build_graph()
|
|
106
107
|
|
|
107
108
|
async def _response_fn(input_message: AIQChatRequest) -> AIQChatResponse:
|
|
108
109
|
try:
|
|
@@ -117,7 +118,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
117
118
|
state = ReActGraphState(messages=messages)
|
|
118
119
|
|
|
119
120
|
# run the ReAct Agent Graph
|
|
120
|
-
state = await graph.ainvoke(state, config={'recursion_limit': (config.
|
|
121
|
+
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
|
|
121
122
|
# setting recursion_limit: 4 allows 1 tool call
|
|
122
123
|
# - allows the ReAct Agent to perform 1 cycle / call 1 single tool,
|
|
123
124
|
# - but stops the agent when it tries to call a tool a second time
|
|
@@ -125,7 +126,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
125
126
|
# get and return the output from the state
|
|
126
127
|
state = ReActGraphState(**state)
|
|
127
128
|
output_message = state.messages[-1] # pylint: disable=E1136
|
|
128
|
-
return AIQChatResponse.from_string(output_message.content)
|
|
129
|
+
return AIQChatResponse.from_string(str(output_message.content))
|
|
129
130
|
|
|
130
131
|
except Exception as ex:
|
|
131
132
|
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
|
|
@@ -139,10 +140,10 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
139
140
|
else:
|
|
140
141
|
|
|
141
142
|
async def _str_api_fn(input_message: str) -> str:
|
|
142
|
-
oai_input = GlobalTypeConverter.get().
|
|
143
|
+
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=AIQChatRequest)
|
|
143
144
|
|
|
144
145
|
oai_output = await _response_fn(oai_input)
|
|
145
146
|
|
|
146
|
-
return GlobalTypeConverter.get().
|
|
147
|
+
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
147
148
|
|
|
148
149
|
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|
aiq/agent/rewoo_agent/agent.py
CHANGED
|
@@ -30,12 +30,11 @@ from langgraph.graph import StateGraph
|
|
|
30
30
|
from pydantic import BaseModel
|
|
31
31
|
from pydantic import Field
|
|
32
32
|
|
|
33
|
+
from aiq.agent.base import AGENT_CALL_LOG_MESSAGE
|
|
33
34
|
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
34
|
-
from aiq.agent.base import AGENT_RESPONSE_LOG_MESSAGE
|
|
35
35
|
from aiq.agent.base import INPUT_SCHEMA_MESSAGE
|
|
36
36
|
from aiq.agent.base import NO_INPUT_ERROR_MESSAGE
|
|
37
37
|
from aiq.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
|
|
38
|
-
from aiq.agent.base import TOOL_RESPONSE_LOG_MESSAGE
|
|
39
38
|
from aiq.agent.base import AgentDecision
|
|
40
39
|
from aiq.agent.base import BaseAgent
|
|
41
40
|
|
|
@@ -65,7 +64,7 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
65
64
|
solver_prompt: ChatPromptTemplate,
|
|
66
65
|
tools: list[BaseTool],
|
|
67
66
|
use_tool_schema: bool = True,
|
|
68
|
-
callbacks: list[AsyncCallbackHandler] = None,
|
|
67
|
+
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
69
68
|
detailed_logs: bool = False):
|
|
70
69
|
super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
|
|
71
70
|
|
|
@@ -91,7 +90,7 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
91
90
|
|
|
92
91
|
logger.debug("%s Initialized ReWOO Agent Graph", AGENT_LOG_PREFIX)
|
|
93
92
|
|
|
94
|
-
def _get_tool(self, tool_name):
|
|
93
|
+
def _get_tool(self, tool_name: str):
|
|
95
94
|
try:
|
|
96
95
|
return self.tools_dict.get(tool_name)
|
|
97
96
|
except Exception as ex:
|
|
@@ -180,22 +179,24 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
180
179
|
logger.debug("%s Starting the ReWOO Planner Node", AGENT_LOG_PREFIX)
|
|
181
180
|
|
|
182
181
|
planner = self.planner_prompt | self.llm
|
|
183
|
-
task = state.task.content
|
|
182
|
+
task = str(state.task.content)
|
|
184
183
|
if not task:
|
|
185
184
|
logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
|
|
186
185
|
return {"result": NO_INPUT_ERROR_MESSAGE}
|
|
187
186
|
|
|
188
|
-
plan =
|
|
189
|
-
|
|
190
|
-
|
|
187
|
+
plan = await self._stream_llm(
|
|
188
|
+
planner,
|
|
189
|
+
{"task": task},
|
|
190
|
+
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
191
|
+
)
|
|
191
192
|
|
|
192
|
-
steps = self._parse_planner_output(plan)
|
|
193
|
+
steps = self._parse_planner_output(str(plan.content))
|
|
193
194
|
|
|
194
195
|
if self.detailed_logs:
|
|
195
|
-
agent_response_log_message =
|
|
196
|
+
agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
|
|
196
197
|
logger.info("ReWOO agent planner output: %s", agent_response_log_message)
|
|
197
198
|
|
|
198
|
-
return {"plan":
|
|
199
|
+
return {"plan": plan, "steps": steps}
|
|
199
200
|
|
|
200
201
|
except Exception as ex:
|
|
201
202
|
logger.exception("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
|
|
@@ -213,10 +214,20 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
213
214
|
current_step)
|
|
214
215
|
raise RuntimeError(f"ReWOO Executor is invoked with an invalid step number: {current_step}")
|
|
215
216
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
217
|
+
steps_content = state.steps.content
|
|
218
|
+
if isinstance(steps_content, list) and current_step < len(steps_content):
|
|
219
|
+
step = steps_content[current_step]
|
|
220
|
+
if isinstance(step, dict) and "evidence" in step:
|
|
221
|
+
step_info = step["evidence"]
|
|
222
|
+
placeholder = step_info.get("placeholder", "")
|
|
223
|
+
tool = step_info.get("tool", "")
|
|
224
|
+
tool_input = step_info.get("tool_input", "")
|
|
225
|
+
else:
|
|
226
|
+
logger.error("%s Invalid step format at index %s", AGENT_LOG_PREFIX, current_step)
|
|
227
|
+
return {"intermediate_results": state.intermediate_results}
|
|
228
|
+
else:
|
|
229
|
+
logger.error("%s Invalid steps content or index %s", AGENT_LOG_PREFIX, current_step)
|
|
230
|
+
return {"intermediate_results": state.intermediate_results}
|
|
220
231
|
|
|
221
232
|
intermediate_results = state.intermediate_results
|
|
222
233
|
|
|
@@ -250,12 +261,10 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
250
261
|
|
|
251
262
|
# Run the tool. Try to use structured input, if possible
|
|
252
263
|
tool_input_parsed = self._parse_tool_input(tool_input)
|
|
253
|
-
tool_response = await
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
if tool_response is None or tool_response == "":
|
|
258
|
-
tool_response = "The tool provided an empty response.\n"
|
|
264
|
+
tool_response = await self._call_tool(requested_tool,
|
|
265
|
+
tool_input_parsed,
|
|
266
|
+
RunnableConfig(callbacks=self.callbacks),
|
|
267
|
+
max_retries=3)
|
|
259
268
|
|
|
260
269
|
# ToolMessage only accepts str or list[str | dict] as content.
|
|
261
270
|
# Convert into list if the response is a dict.
|
|
@@ -264,15 +273,8 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
264
273
|
|
|
265
274
|
tool_response_message = ToolMessage(name=tool, tool_call_id=tool, content=tool_response)
|
|
266
275
|
|
|
267
|
-
logger.debug("%s Successfully called the tool", AGENT_LOG_PREFIX)
|
|
268
276
|
if self.detailed_logs:
|
|
269
|
-
|
|
270
|
-
tool_response_str = tool_response_message.content
|
|
271
|
-
tool_response_str = tool_response_str[:1000] + "..." if len(
|
|
272
|
-
tool_response_str) > 1000 else tool_response_str
|
|
273
|
-
tool_response_log_message = TOOL_RESPONSE_LOG_MESSAGE % (
|
|
274
|
-
requested_tool.name, tool_input_parsed, tool_response_str)
|
|
275
|
-
logger.info("ReWOO agent executor output: %s", tool_response_log_message)
|
|
277
|
+
self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
|
|
276
278
|
|
|
277
279
|
intermediate_results[placeholder] = tool_response_message
|
|
278
280
|
return {"intermediate_results": intermediate_results}
|
|
@@ -308,16 +310,15 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
308
310
|
tool = step_info.get("tool")
|
|
309
311
|
plan += f"Plan: {_plan}\n{placeholder} = {tool}[{tool_input}]"
|
|
310
312
|
|
|
311
|
-
task = state.task.content
|
|
313
|
+
task = str(state.task.content)
|
|
312
314
|
solver_prompt = self.solver_prompt.partial(plan=plan)
|
|
313
315
|
solver = solver_prompt | self.llm
|
|
314
|
-
output_message = ""
|
|
315
|
-
async for event in solver.astream({"task": task}, config=RunnableConfig(callbacks=self.callbacks)):
|
|
316
|
-
output_message += event.content
|
|
317
316
|
|
|
318
|
-
output_message =
|
|
317
|
+
output_message = await self._stream_llm(solver, {"task": task},
|
|
318
|
+
RunnableConfig(callbacks=self.callbacks)) # type: ignore
|
|
319
|
+
|
|
319
320
|
if self.detailed_logs:
|
|
320
|
-
solver_output_log_message =
|
|
321
|
+
solver_output_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(output_message.content))
|
|
321
322
|
logger.info("ReWOO agent solver output: %s", solver_output_log_message)
|
|
322
323
|
|
|
323
324
|
return {"result": output_message}
|
|
@@ -150,9 +150,9 @@ async def ReWOO_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
150
150
|
else:
|
|
151
151
|
|
|
152
152
|
async def _str_api_fn(input_message: str) -> str:
|
|
153
|
-
oai_input = GlobalTypeConverter.get().
|
|
153
|
+
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=AIQChatRequest)
|
|
154
154
|
oai_output = await _response_fn(oai_input)
|
|
155
155
|
|
|
156
|
-
return GlobalTypeConverter.get().
|
|
156
|
+
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
157
157
|
|
|
158
158
|
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|
|
@@ -25,9 +25,8 @@ from langgraph.prebuilt import ToolNode
|
|
|
25
25
|
from pydantic import BaseModel
|
|
26
26
|
from pydantic import Field
|
|
27
27
|
|
|
28
|
+
from aiq.agent.base import AGENT_CALL_LOG_MESSAGE
|
|
28
29
|
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
29
|
-
from aiq.agent.base import AGENT_RESPONSE_LOG_MESSAGE
|
|
30
|
-
from aiq.agent.base import TOOL_RESPONSE_LOG_MESSAGE
|
|
31
30
|
from aiq.agent.base import AgentDecision
|
|
32
31
|
from aiq.agent.dual_node import DualNodeAgent
|
|
33
32
|
|
|
@@ -62,7 +61,7 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
62
61
|
response = await self.llm.ainvoke(state.messages, config=RunnableConfig(callbacks=self.callbacks))
|
|
63
62
|
if self.detailed_logs:
|
|
64
63
|
agent_input = "\n".join(str(message.content) for message in state.messages)
|
|
65
|
-
logger.info(
|
|
64
|
+
logger.info(AGENT_CALL_LOG_MESSAGE, agent_input, response)
|
|
66
65
|
|
|
67
66
|
state.messages += [response]
|
|
68
67
|
return state
|
|
@@ -102,10 +101,7 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
102
101
|
|
|
103
102
|
for response in tool_response.get('messages'):
|
|
104
103
|
if self.detailed_logs:
|
|
105
|
-
|
|
106
|
-
response.content = response.content[:1000] + "..." if len(
|
|
107
|
-
response.content) > 1000 else response.content
|
|
108
|
-
logger.info(TOOL_RESPONSE_LOG_MESSAGE, tools, tool_input, response.content)
|
|
104
|
+
self._log_tool_response(str(tools), str(tool_input), response.content)
|
|
109
105
|
state.messages += [response]
|
|
110
106
|
|
|
111
107
|
return state
|
|
@@ -38,7 +38,7 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
|
|
|
38
38
|
tool_names: list[FunctionRef] = Field(default_factory=list,
|
|
39
39
|
description="The list of tools to provide to the tool calling agent.")
|
|
40
40
|
llm_name: LLMRef = Field(description="The LLM model to use with the tool calling agent.")
|
|
41
|
-
verbose: bool = Field(default=False, description="Set the verbosity of the
|
|
41
|
+
verbose: bool = Field(default=False, description="Set the verbosity of the tool calling agent's logging.")
|
|
42
42
|
handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
|
|
43
43
|
description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
|
|
44
44
|
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import SecretStr
|
|
19
|
+
|
|
20
|
+
from aiq.authentication.api_key.api_key_auth_provider_config import APIKeyAuthProviderConfig
|
|
21
|
+
from aiq.authentication.interfaces import AuthProviderBase
|
|
22
|
+
from aiq.data_models.authentication import AuthResult
|
|
23
|
+
from aiq.data_models.authentication import BearerTokenCred
|
|
24
|
+
from aiq.data_models.authentication import HeaderAuthScheme
|
|
25
|
+
|
|
26
|
+
# Per-module logger (no log calls are made from this module at present).
logger = logging.getLogger(name=__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
    """Auth provider that turns a statically configured API key into a header credential.

    The configured ``raw_key`` is wrapped in a :class:`BearerTokenCred` whose ``scheme`` and
    ``header_name`` depend on ``config.auth_scheme`` (BEARER, X_API_KEY or CUSTOM).
    """

    def __init__(self, config: APIKeyAuthProviderConfig, config_name: str | None = None) -> None:
        """
        Args:
            config (APIKeyAuthProviderConfig): The API key provider configuration.
            config_name (str | None): Accepted for signature compatibility; not used by this provider.
        """
        assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyConfig")
        super().__init__(config)

    async def _construct_authentication_header(self) -> BearerTokenCred:
        """
        Build the header credential for the scheme configured in ``self.config.auth_scheme``.

        - BEARER: scheme ``HeaderAuthScheme.BEARER.value`` on header ``AUTHORIZATION_HEADER``.
        - X_API_KEY: scheme ``HeaderAuthScheme.X_API_KEY.value`` with an empty ``header_name``.
        - CUSTOM: scheme ``config.custom_header_prefix`` on header ``config.custom_header_name``.

        Returns:
            BearerTokenCred: The credential carrying ``config.raw_key`` as a ``SecretStr``.

        Raises:
            ValueError: If CUSTOM is selected without both ``custom_header_name`` and
                ``custom_header_prefix``, or if the scheme is any other value (for example
                BASIC, which this provider does not handle).
        """

        from aiq.authentication.interfaces import AUTHORIZATION_HEADER

        config: APIKeyAuthProviderConfig = self.config
        header_auth_scheme = config.auth_scheme

        # Wrap the plain-string key once so it is not exposed in reprs of the credential.
        token = SecretStr(str(config.raw_key))

        if header_auth_scheme == HeaderAuthScheme.BEARER:
            return BearerTokenCred(token=token, scheme=HeaderAuthScheme.BEARER.value, header_name=AUTHORIZATION_HEADER)

        if header_auth_scheme == HeaderAuthScheme.X_API_KEY:
            # NOTE(review): header_name is intentionally left empty here; presumably the consumer
            # derives the header from the scheme value - confirm against BearerTokenCred consumers.
            return BearerTokenCred(token=token, scheme=HeaderAuthScheme.X_API_KEY.value, header_name='')

        if header_auth_scheme == HeaderAuthScheme.CUSTOM:
            # Both custom fields default to None on the config, so enforce them at use time.
            if not config.custom_header_name:
                raise ValueError('custom_header_name required when using header_auth_scheme=CUSTOM')

            if not config.custom_header_prefix:
                raise ValueError('custom_header_prefix required when using header_auth_scheme=CUSTOM')

            return BearerTokenCred(token=token,
                                   scheme=config.custom_header_prefix,
                                   header_name=config.custom_header_name)

        raise ValueError(f"Unsupported header auth scheme: {header_auth_scheme}")

    async def authenticate(self, user_id: str | None = None) -> AuthResult | None:
        """
        Authenticate using the configured API key.

        Args:
            user_id (str | None): Accepted for interface compatibility; the key is static
                configuration, so it is not used.

        Returns:
            AuthResult: A result whose ``credentials`` list holds the single header credential
            built from the configuration.

        Raises:
            ValueError: Propagated from header construction when the configuration cannot be
                turned into a credential.
        """

        credential = await self._construct_authentication_header()

        return AuthResult(credentials=[credential])
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import re
|
|
18
|
+
import string
|
|
19
|
+
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
from pydantic import field_validator
|
|
22
|
+
|
|
23
|
+
from aiq.authentication.exceptions.api_key_exceptions import APIKeyFieldError
|
|
24
|
+
from aiq.authentication.exceptions.api_key_exceptions import HeaderNameFieldError
|
|
25
|
+
from aiq.authentication.exceptions.api_key_exceptions import HeaderPrefixFieldError
|
|
26
|
+
from aiq.data_models.authentication import AuthProviderBaseConfig
|
|
27
|
+
from aiq.data_models.authentication import HeaderAuthScheme
|
|
28
|
+
|
|
29
|
+
# Per-module logger (no log calls are made from this module at present).
logger = logging.getLogger(name=__name__)

# A header field name must be an RFC 7230 "token": one or more "tchar" characters
# (ASCII letters, digits, and the listed punctuation).
HEADER_NAME_REGEX = re.compile(pattern=r"^[!#$%&'*+\-.^_`|~0-9a-zA-Z]+$")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class APIKeyAuthProviderConfig(AuthProviderBaseConfig, name="api_key"):
    """
    API Key authentication configuration model.

    Each field validator raises one of the error-code-tagged exceptions from
    ``aiq.authentication.exceptions.api_key_exceptions`` on invalid input.
    """

    # NOTE(review): the exceptions raised by the validators below derive from ``Exception``,
    # not ``ValueError``; pydantic only converts ValueError/AssertionError into a
    # ValidationError, so these most likely propagate unwrapped - confirm callers expect that.

    raw_key: str = Field(description=("Raw API token or credential to be injected into the request parameter. "
                                      "Used for 'bearer','x-api-key','custom', and other schemes. "))

    # NOTE(review): the description lists BASIC, but APIKeyAuthProvider has no BASIC branch and
    # raises ValueError for it.
    auth_scheme: HeaderAuthScheme = Field(default=HeaderAuthScheme.BEARER,
                                          description=("The HTTP authentication scheme to use. "
                                                       "Supported schemes: BEARER, X_API_KEY, BASIC, CUSTOM."))

    # Both custom_* fields are only meaningful for auth_scheme CUSTOM; the provider enforces
    # their presence at use time.
    custom_header_name: str | None = Field(description="The HTTP header name that MUST be used in conjunction "
                                           "with the custom_header_prefix when HeaderAuthScheme is CUSTOM.",
                                           default=None)
    custom_header_prefix: str | None = Field(description="The HTTP header prefix that MUST be used in conjunction "
                                             "with the custom_header_name when HeaderAuthScheme is CUSTOM.",
                                             default=None)

    @field_validator('raw_key')
    @classmethod
    def validate_raw_key(cls, value: str) -> str:
        """Reject empty, short (< 8 chars) or whitespace-containing API keys."""
        if not value:
            raise APIKeyFieldError('value_missing', 'raw_key field value is required.')

        if len(value) < 8:
            raise APIKeyFieldError(
                'value_too_short',
                'raw_key field value must be at least 8 characters long for security. '
                f'Got: {len(value)} characters.')

        # Checked before the general whitespace scan so leading/trailing whitespace gets its own error code.
        if len(value.strip()) != len(value):
            raise APIKeyFieldError('whitespace_found',
                                   'raw_key field value cannot have leading or trailing whitespace.')

        if any(c in string.whitespace for c in value):
            raise APIKeyFieldError('contains_whitespace', 'raw_key must not contain any '
                                   'whitespace characters.')

        return value

    @field_validator('custom_header_name')
    @classmethod
    def validate_custom_header_name(cls, value: str | None) -> str | None:
        """Validate an optional header name against the HTTP token syntax (``HEADER_NAME_REGEX``).

        ``None`` means "not set" and is accepted; an empty string is rejected.
        """
        # The field is declared ``str | None`` with default None, so an explicitly supplied
        # None must be treated the same as omitting the field.
        if value is None:
            return None

        if not value:
            raise HeaderNameFieldError('value_missing', 'custom_header_name is required.')

        if value != value.strip():
            raise HeaderNameFieldError('whitespace_found',
                                       'custom_header_name field value cannot have leading or trailing whitespace.')

        if any(c in string.whitespace for c in value):
            raise HeaderNameFieldError('contains_whitespace',
                                       'custom_header_name must not contain any whitespace characters.')

        if not HEADER_NAME_REGEX.fullmatch(value):
            raise HeaderNameFieldError(
                'invalid_format',
                'custom_header_name must match the HTTP token syntax: ASCII letters, digits, or allowed symbols.')

        return value

    @field_validator('custom_header_prefix')
    @classmethod
    def validate_custom_header_prefix(cls, value: str | None) -> str | None:
        """Validate an optional header prefix: non-empty, no whitespace, ASCII only.

        ``None`` means "not set" and is accepted; an empty string is rejected.
        """
        # Same reasoning as for custom_header_name: explicit None equals "not set".
        if value is None:
            return None

        if not value:
            raise HeaderPrefixFieldError('value_missing', 'custom_header_prefix is required.')

        if value != value.strip():
            raise HeaderPrefixFieldError(
                'whitespace_found', 'custom_header_prefix field value cannot have '
                'leading or trailing whitespace.')

        if any(c in string.whitespace for c in value):
            raise HeaderPrefixFieldError('contains_whitespace',
                                         'custom_header_prefix must not contain any whitespace characters.')

        if not value.isascii():
            raise HeaderPrefixFieldError('invalid_format', 'custom_header_prefix must be ASCII.')

        return value

    # NOTE(review): field_validator already defaults to mode='after', so this re-checks what
    # validate_raw_key enforces; kept so the classmethod remains available to existing callers.
    @field_validator('raw_key', mode='after')
    @classmethod
    def validate_raw_key_after(cls, value: str) -> str:
        """Ensure raw_key is non-empty after pydantic's own validation."""
        if not value:
            raise APIKeyFieldError('value_missing', 'raw_key field value is '
                                   'required after construction.')

        return value
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from aiq.authentication.api_key.api_key_auth_provider_config import APIKeyAuthProviderConfig
|
|
17
|
+
from aiq.builder.builder import Builder
|
|
18
|
+
from aiq.cli.register_workflow import register_auth_provider
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_auth_provider(config_type=APIKeyAuthProviderConfig)
async def api_key_client(config: APIKeyAuthProviderConfig, builder: Builder):
    """Yield an API key auth provider built from ``config``.

    Registered through ``register_auth_provider``. ``builder`` is part of the registration
    signature but is not needed to build this provider.
    """
    # Deferred import: the provider module is imported only when this generator runs.
    from aiq.authentication.api_key.api_key_auth_provider import APIKeyAuthProvider as _Provider

    provider = _Provider(config=config)
    yield provider
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _APIKeyConfigFieldError(Exception):
    """Base class for API key config field validation errors.

    The machine-readable code is stored on ``error_code`` and prefixed to the message as
    ``"[code] message"``.
    """

    def __init__(self, error_code: str, message: str, *args):
        self.error_code = error_code
        # Keep the unformatted message so the exception can be reconstructed when pickled/copied.
        self.message = message
        super().__init__(f"[{error_code}] {message}", *args)

    def __reduce__(self):
        # BaseException's default reduction recreates the instance as ``cls(*self.args)``, i.e.
        # ``cls("[code] message", ...)``, which does not match this constructor (``message``
        # would be missing). Rebuild from the original constructor arguments instead.
        return (type(self), (self.error_code, self.message, *self.args[1:]), self.__dict__)


class APIKeyFieldError(_APIKeyConfigFieldError):
    """Raised when validation of the API key config ``raw_key`` field fails."""


class HeaderNameFieldError(_APIKeyConfigFieldError):
    """Raised when validation of the API key config ``custom_header_name`` field fails."""


class HeaderPrefixFieldError(_APIKeyConfigFieldError):
    """Raised when validation of the API key config ``custom_header_prefix`` field fails."""
|