aiqtoolkit 1.2.0a20250706__py3-none-any.whl → 1.2.0a20250730__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/base.py +171 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +113 -113
- aiq/agent/react_agent/register.py +31 -14
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +26 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +32 -4
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +420 -73
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +1 -0
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +117 -11
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +15 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +75 -1
- aiq/eval/evaluate.py +53 -10
- aiq/eval/rag_evaluator/evaluate.py +23 -12
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/usage_stats.py +6 -0
- aiq/eval/utils/weave_eval.py +5 -1
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +20 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +353 -31
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +102 -84
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +2 -1
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +2 -1
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +32 -116
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/data_models.py +2 -0
- aiq/profiler/profile_runner.py +16 -13
- aiq/runtime/loader.py +8 -2
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +16 -5
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +17 -3
- aiq/tool/mcp/mcp_tool.py +1 -1
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +2 -2
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/METADATA +37 -9
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/RECORD +195 -80
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -153
- aiq/observability/async_otel_listener.py +0 -470
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/WHEEL +0 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
import matplotlib.pyplot as plt
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
|
|
23
|
+
from aiq.profiler.calc.data_models import LinearFitResult
|
|
24
|
+
from aiq.profiler.calc.data_models import SizingMetrics
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Plotting constants
|
|
30
|
+
class PlotConfig:
    """
    Styling constants shared by the concurrency plots in this module.

    The class is used purely as a namespace; attributes are read directly (e.g. ``PlotConfig.SIMPLE_DPI``).
    """

    # Simple plot settings
    SIMPLE_FIGSIZE = (12, 6)
    SIMPLE_LINEWIDTH = 2
    SIMPLE_DPI = 150

    # Enhanced plot settings
    ENHANCED_FIGSIZE = (16, 6)
    ENHANCED_DPI = 300

    # Marker and styling
    DATA_MARKER = 'o'
    OUTLIER_MARKER = 'x'
    OUTLIER_COLOR = 'crimson'
    TREND_COLOR = 'r'
    TREND_LINESTYLE = '--'
    TREND_ALPHA = 0.8
    TREND_LINEWIDTH = 2.0

    # Colors
    LLM_LATENCY_COLOR = 'steelblue'
    RUNTIME_COLOR = 'darkgreen'
    SLA_COLOR = 'red'
    NOTE_BOX_COLOR = 'mistyrose'
    NOTE_TEXT_COLOR = 'crimson'
    STATS_BOX_COLOR = 'lightblue'

    # Alpha values
    DATA_ALPHA = 0.7
    OUTLIER_ALPHA = 0.9
    GRID_ALPHA = 0.3
    SLA_ALPHA = 0.7
    NOTE_BOX_ALPHA = 0.7
    STATS_BOX_ALPHA = 0.8

    # Sizes
    DATA_POINT_SIZE = 120
    OUTLIER_POINT_SIZE = 140
    DATA_LINEWIDTH = 1

    # Font sizes
    AXIS_LABEL_FONTSIZE = 12
    TITLE_FONTSIZE = 14
    LEGEND_FONTSIZE = 10
    NOTE_FONTSIZE = 10
    STATS_FONTSIZE = 10

    # Text positioning (passed as x/y to ax.text with transAxes and to fig.text)
    NOTE_X_POS = 0.98
    NOTE_Y_POS = 0.02
    STATS_X_POS = 0.02
    STATS_Y_POS = 0.02

    # Box styling
    NOTE_BOX_PAD = 0.3
    STATS_BOX_PAD = 0.5

    # Number of samples used to draw the trend line
    TREND_LINE_POINTS = 100

    # Font weights
    AXIS_LABEL_FONTWEIGHT = 'bold'
    TITLE_FONTWEIGHT = 'bold'
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def plot_concurrency_vs_time_metrics_simple(df: pd.DataFrame, output_dir: Path) -> None:
    """
    Save a simple line plot of concurrency vs. p95 LLM latency and p95 workflow runtime.

    Args:
        df: Table providing the "concurrency", "llm_latency_p95" and "workflow_runtime_p95" columns.
        output_dir: Directory that receives "concurrency_vs_p95_simple.png". It is created if missing.
    """
    fig = plt.figure(figsize=PlotConfig.SIMPLE_FIGSIZE)
    try:
        plt.plot(df["concurrency"],
                 df["llm_latency_p95"],
                 label="p95 LLM Latency (s)",
                 marker=PlotConfig.DATA_MARKER,
                 linewidth=PlotConfig.SIMPLE_LINEWIDTH)
        plt.plot(df["concurrency"],
                 df["workflow_runtime_p95"],
                 label="p95 Workflow Runtime (s)",
                 marker="s",
                 linewidth=PlotConfig.SIMPLE_LINEWIDTH)
        plt.xlabel("Concurrency")
        plt.ylabel("Time (seconds)")
        plt.title("Concurrency vs. p95 LLM Latency and Workflow Runtime")
        plt.grid(True, alpha=PlotConfig.GRID_ALPHA)
        plt.legend()
        plt.tight_layout()

        # savefig does not create missing directories, so make sure the target exists first.
        output_dir.mkdir(parents=True, exist_ok=True)
        simple_plot_path = output_dir / "concurrency_vs_p95_simple.png"
        plt.savefig(simple_plot_path, dpi=PlotConfig.SIMPLE_DPI, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    logger.info("Simple plot saved to %s", simple_plot_path)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def plot_metric_vs_concurrency_with_optional_fit(
    ax: plt.Axes,
    x: np.ndarray,
    y: np.ndarray,
    metric_name: str,
    y_label: str,
    title: str,
    color: str,
    sla_value: float = 0.0,
    sla_label: str | None = None,
    fit: LinearFitResult | None = None,
) -> bool:
    """
    Helper to plot a metric vs concurrency with a pre-computed fit, outlier highlighting, and an SLA line.

    Args:
        ax: Axes to draw on.
        x: Concurrency values.
        y: Metric values, indexed in lockstep with ``x``.
        metric_name: Name used only in the "no fit" warning message.
        y_label: Label for the y axis.
        title: Title of the axes.
        color: Color of the regular data points.
        sla_value: Horizontal threshold line is drawn only when this is greater than 0.
        sla_label: Legend label for the threshold line; a default label is generated when empty or None.
        fit: Pre-computed linear fit. Its ``slope``, ``intercept``, ``r_squared`` and ``outliers_removed``
            attributes are read.

    Returns:
        False when no fit is available (nothing is drawn), True once the axes has been populated.
    """
    # Skip analysis plot if no fit is available
    if not fit:
        logger.warning("No linear fit available for %s, skipping analysis plot", metric_name)
        return False

    # Keyword arguments shared by every "regular data point" scatter call.
    data_style = dict(alpha=PlotConfig.DATA_ALPHA,
                      s=PlotConfig.DATA_POINT_SIZE,
                      c=color,
                      edgecolors='white',
                      linewidth=PlotConfig.DATA_LINEWIDTH,
                      marker=PlotConfig.DATA_MARKER,
                      label='Data Points')

    outliers_note = ""
    if fit.outliers_removed:
        # outliers_removed holds concurrency values, so points are matched on x.
        outlier_mask = np.isin(x, fit.outliers_removed)
        outliers_note = f"Outliers removed: concurrencies {fit.outliers_removed}"

        # Plot cleaned data (points that weren't removed as outliers) first, then the outliers on top.
        ax.scatter(x[~outlier_mask], y[~outlier_mask], **data_style)
        ax.scatter(x[outlier_mask],
                   y[outlier_mask],
                   alpha=PlotConfig.OUTLIER_ALPHA,
                   s=PlotConfig.OUTLIER_POINT_SIZE,
                   c=PlotConfig.OUTLIER_COLOR,
                   marker=PlotConfig.OUTLIER_MARKER,
                   label='Removed Outliers')
    else:
        # No outliers: plot all data points
        ax.scatter(x, y, **data_style)

    # Plot trend line using the fit, sampled across the full observed concurrency range
    x_fit = np.linspace(x.min(), x.max(), PlotConfig.TREND_LINE_POINTS)
    y_fit = fit.slope * x_fit + fit.intercept
    ax.plot(x_fit,
            y_fit,
            PlotConfig.TREND_LINESTYLE,
            alpha=PlotConfig.TREND_ALPHA,
            linewidth=PlotConfig.TREND_LINEWIDTH,
            color=PlotConfig.TREND_COLOR,
            label=f'Trend (slope={fit.slope:.4f}, R²={fit.r_squared:.3f})')

    if sla_value > 0:
        ax.axhline(y=sla_value,
                   color=PlotConfig.SLA_COLOR,
                   linestyle=':',
                   alpha=PlotConfig.SLA_ALPHA,
                   linewidth=2,
                   label=sla_label or f'SLA Threshold ({sla_value}s)')

    ax.set_xlabel('Concurrency', fontsize=PlotConfig.AXIS_LABEL_FONTSIZE, fontweight=PlotConfig.AXIS_LABEL_FONTWEIGHT)
    ax.set_ylabel(y_label, fontsize=PlotConfig.AXIS_LABEL_FONTSIZE, fontweight=PlotConfig.AXIS_LABEL_FONTWEIGHT)
    ax.set_title(title, fontsize=PlotConfig.TITLE_FONTSIZE, fontweight=PlotConfig.TITLE_FONTWEIGHT)
    ax.grid(True, alpha=PlotConfig.GRID_ALPHA)
    ax.legend(fontsize=PlotConfig.LEGEND_FONTSIZE)

    if outliers_note:
        # Anchor the note in axes coordinates so it stays in the corner regardless of data range.
        ax.text(PlotConfig.NOTE_X_POS,
                PlotConfig.NOTE_Y_POS,
                outliers_note,
                transform=ax.transAxes,
                fontsize=PlotConfig.NOTE_FONTSIZE,
                color=PlotConfig.NOTE_TEXT_COLOR,
                ha='right',
                va='bottom',
                bbox=dict(boxstyle=f'round,pad={PlotConfig.NOTE_BOX_PAD}',
                          facecolor=PlotConfig.NOTE_BOX_COLOR,
                          alpha=PlotConfig.NOTE_BOX_ALPHA))

    return True
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def plot_concurrency_vs_time_metrics(metrics_per_concurrency: dict[int, SizingMetrics],
                                     output_dir: Path,
                                     target_llm_latency: float = 0.0,
                                     target_runtime: float = 0.0,
                                     llm_latency_fit: LinearFitResult | None = None,
                                     runtime_fit: LinearFitResult | None = None) -> None:
    """
    Plot concurrency vs. p95 LLM latency and p95 workflow runtime using metrics_per_concurrency.

    A simple line plot is always written. An additional analysis plot with trend lines, outlier
    highlighting and SLA thresholds is written when at least one pre-computed fit is supplied; with a
    single fit the analysis plot contains one panel for that metric, with both fits it contains two.

    Every entry of ``metrics_per_concurrency`` is plotted; no filtering happens here.
    NOTE(review): callers are presumably expected to pass only valid runs - confirm.

    Args:
        metrics_per_concurrency: Mapping of concurrency to metrics; ``llm_latency_p95`` and
            ``workflow_runtime_p95`` are read from each value.
        output_dir: Directory that receives the PNG files. It is created if missing.
        target_llm_latency: SLA threshold drawn on the LLM latency panel when greater than 0.
        target_runtime: SLA threshold drawn on the workflow runtime panel when greater than 0.
        llm_latency_fit: Pre-computed linear fit for LLM latency, or None to skip that panel.
        runtime_fit: Pre-computed linear fit for workflow runtime, or None to skip that panel.
    """
    rows = [{
        "concurrency": concurrency,
        "llm_latency_p95": metrics.llm_latency_p95,
        "workflow_runtime_p95": metrics.workflow_runtime_p95,
    } for concurrency, metrics in metrics_per_concurrency.items()]

    if not rows:
        logger.warning("No valid metrics data available to plot.")
        return

    plt.style.use('seaborn-v0_8')
    df = pd.DataFrame(rows).sort_values("concurrency")

    # Both plots are saved into output_dir, so create it before the first save.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Always generate simple plot first
    plot_concurrency_vs_time_metrics_simple(df, output_dir)

    # Check if we have fits available for analysis plots
    has_llm_latency_fit = llm_latency_fit is not None
    has_runtime_fit = runtime_fit is not None

    if not has_llm_latency_fit and not has_runtime_fit:
        logger.warning("No linear fits available for analysis plots, skipping enhanced plot")
        return

    # Create subplots based on available fits
    if has_llm_latency_fit and has_runtime_fit:
        fig, (llm_latency_ax, runtime_ax) = plt.subplots(1, 2, figsize=PlotConfig.ENHANCED_FIGSIZE)
    else:
        # Exactly one fit exists: the single panel goes to whichever metric has it, so a
        # runtime-only fit is plotted instead of being silently dropped.
        fig, single_ax = plt.subplots(1, 1, figsize=(8, 6))
        llm_latency_ax = single_ax if has_llm_latency_fit else None
        runtime_ax = single_ax if has_runtime_fit else None

    try:
        concurrencies = df["concurrency"].to_numpy()

        # Plot llm_latency if fit is available
        llm_latency_plotted = False
        if llm_latency_ax is not None:
            llm_latency_plotted = plot_metric_vs_concurrency_with_optional_fit(
                llm_latency_ax,
                concurrencies,
                df["llm_latency_p95"].to_numpy(),
                metric_name="llm_latency",
                y_label='P95 LLM Latency (seconds)',
                title='Concurrency vs P95 LLM Latency',
                color=PlotConfig.LLM_LATENCY_COLOR,
                sla_value=target_llm_latency,
                sla_label=f'SLA Threshold ({target_llm_latency}s)' if target_llm_latency > 0 else None,
                fit=llm_latency_fit,
            )

        # Plot runtime if fit is available
        runtime_plotted = False
        if runtime_ax is not None:
            runtime_plotted = plot_metric_vs_concurrency_with_optional_fit(
                runtime_ax,
                concurrencies,
                df["workflow_runtime_p95"].to_numpy(),
                metric_name="runtime",
                y_label='P95 Workflow Runtime (seconds)',
                title='Concurrency vs P95 Workflow Runtime',
                color=PlotConfig.RUNTIME_COLOR,
                sla_value=target_runtime,
                sla_label=f'SLA Threshold ({target_runtime}s)' if target_runtime > 0 else None,
                fit=runtime_fit,
            )

        # Check if any plots were successfully created
        if not (llm_latency_plotted or runtime_plotted):
            logger.warning("No analysis plots could be created, skipping enhanced plot")
            return

        # Add summary statistics
        stats_text = f'Data Points: {len(df)}\n'
        stats_text += f'LLM Latency Range: {df["llm_latency_p95"].min():.3f}-{df["llm_latency_p95"].max():.3f}s\n'
        stats_text += f'WF Runtime Range: {df["workflow_runtime_p95"].min():.3f}-{df["workflow_runtime_p95"].max():.3f}s'

        fig.text(PlotConfig.STATS_X_POS,
                 PlotConfig.STATS_Y_POS,
                 stats_text,
                 fontsize=PlotConfig.STATS_FONTSIZE,
                 bbox=dict(boxstyle=f'round,pad={PlotConfig.STATS_BOX_PAD}',
                           facecolor=PlotConfig.STATS_BOX_COLOR,
                           alpha=PlotConfig.STATS_BOX_ALPHA))

        plt.tight_layout()

        enhanced_plot_path = output_dir / "concurrency_vs_p95_analysis.png"
        plt.savefig(enhanced_plot_path,
                    dpi=PlotConfig.ENHANCED_DPI,
                    bbox_inches='tight',
                    facecolor='white',
                    edgecolor='none')
    finally:
        # Release the figure on every exit path, including the early return and failures.
        plt.close(fig)

    logger.info("Enhanced plot saved to %s", enhanced_plot_path)
|
aiq/profiler/data_models.py
CHANGED
|
@@ -15,8 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
from pydantic import BaseModel
|
|
17
17
|
|
|
18
|
+
from aiq.profiler.inference_metrics_model import InferenceMetricsModel
|
|
18
19
|
from aiq.profiler.inference_optimization.data_models import WorkflowRuntimeMetrics
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class ProfilerResults(BaseModel):
|
|
22
23
|
workflow_runtime_metrics: WorkflowRuntimeMetrics | None = None
|
|
24
|
+
llm_latency_ci: InferenceMetricsModel | None = None
|
aiq/profiler/profile_runner.py
CHANGED
|
@@ -68,9 +68,10 @@ class ProfilerRunner:
|
|
|
68
68
|
All computed metrics are saved to a metrics JSON file at the end.
|
|
69
69
|
"""
|
|
70
70
|
|
|
71
|
-
def __init__(self, profiler_config: ProfilerConfig, output_dir: Path):
|
|
71
|
+
def __init__(self, profiler_config: ProfilerConfig, output_dir: Path, write_output: bool = True):
|
|
72
72
|
self.profile_config = profiler_config
|
|
73
73
|
self.output_dir = output_dir
|
|
74
|
+
self.write_output = write_output
|
|
74
75
|
self._converter = TypeConverter([])
|
|
75
76
|
|
|
76
77
|
# Holds per-request data (prompt, output, usage_stats, etc.)
|
|
@@ -114,10 +115,11 @@ class ProfilerRunner:
|
|
|
114
115
|
self.all_requests_data.append({"request_number": i, "intermediate_steps": request_data})
|
|
115
116
|
|
|
116
117
|
# Write the final big JSON (all requests)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
118
|
+
if self.write_output:
|
|
119
|
+
final_path = os.path.join(self.output_dir, "all_requests_profiler_traces.json")
|
|
120
|
+
with open(final_path, 'w', encoding='utf-8') as f:
|
|
121
|
+
json.dump(self.all_requests_data, f, indent=2, default=str)
|
|
122
|
+
logger.info("Wrote combined data to: %s", final_path)
|
|
121
123
|
|
|
122
124
|
# ------------------------------------------------------------
|
|
123
125
|
# Generate one standardized dataframe for all usage stats
|
|
@@ -185,7 +187,7 @@ class ProfilerRunner:
|
|
|
185
187
|
token_uniqueness=token_uniqueness_results,
|
|
186
188
|
workflow_runtimes=workflow_runtimes_results)
|
|
187
189
|
|
|
188
|
-
if inference_optimization_results:
|
|
190
|
+
if self.write_output and inference_optimization_results:
|
|
189
191
|
# Save to JSON
|
|
190
192
|
optimization_results_path = os.path.join(self.output_dir, "inference_optimization.json")
|
|
191
193
|
with open(optimization_results_path, 'w', encoding='utf-8') as f:
|
|
@@ -249,14 +251,14 @@ class ProfilerRunner:
|
|
|
249
251
|
exclude=["textual_report"])
|
|
250
252
|
logger.info("Prefix span analysis complete")
|
|
251
253
|
|
|
252
|
-
if workflow_profiling_reports:
|
|
254
|
+
if self.write_output and workflow_profiling_reports:
|
|
253
255
|
# Save to text file
|
|
254
256
|
profiling_report_path = os.path.join(self.output_dir, "workflow_profiling_report.txt")
|
|
255
257
|
with open(profiling_report_path, 'w', encoding='utf-8') as f:
|
|
256
258
|
f.write(workflow_profiling_reports)
|
|
257
259
|
logger.info("Wrote workflow profiling report to: %s", profiling_report_path)
|
|
258
260
|
|
|
259
|
-
if workflow_profiling_metrics:
|
|
261
|
+
if self.write_output and workflow_profiling_metrics:
|
|
260
262
|
# Save to JSON
|
|
261
263
|
profiling_metrics_path = os.path.join(self.output_dir, "workflow_profiling_metrics.json")
|
|
262
264
|
with open(profiling_metrics_path, 'w', encoding='utf-8') as f:
|
|
@@ -278,15 +280,16 @@ class ProfilerRunner:
|
|
|
278
280
|
logger.exception("Fitting model failed. %s", e, exc_info=True)
|
|
279
281
|
return ProfilerResults()
|
|
280
282
|
|
|
281
|
-
|
|
283
|
+
if self.write_output:
|
|
284
|
+
os.makedirs(self.output_dir, exist_ok=True)
|
|
282
285
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
+
import pickle
|
|
287
|
+
with open(os.path.join(self.output_dir, "fitted_model.pkl"), 'wb') as f:
|
|
288
|
+
pickle.dump(fitted_model, f)
|
|
286
289
|
|
|
287
290
|
logger.info("Saved fitted model to disk.")
|
|
288
291
|
|
|
289
|
-
return ProfilerResults(workflow_runtime_metrics=workflow_runtimes_results)
|
|
292
|
+
return ProfilerResults(workflow_runtime_metrics=workflow_runtimes_results, llm_latency_ci=llm_latency_ci)
|
|
290
293
|
|
|
291
294
|
# -------------------------------------------------------------------
|
|
292
295
|
# Confidence Intervals / Metrics
|
aiq/runtime/loader.py
CHANGED
|
@@ -48,14 +48,18 @@ class PluginTypes(IntFlag):
|
|
|
48
48
|
"""
|
|
49
49
|
A plugin that is an evaluator for the workflow. This includes evaluators like RAGAS, SWE-bench, etc.
|
|
50
50
|
"""
|
|
51
|
+
AUTHENTICATION = auto()
|
|
52
|
+
"""
|
|
53
|
+
A plugin that is an API authentication provider for the workflow. This includes Oauth2, API Key, etc.
|
|
54
|
+
"""
|
|
51
55
|
REGISTRY_HANDLER = auto()
|
|
52
56
|
|
|
53
57
|
# Convenience flag for groups of plugin types
|
|
54
|
-
CONFIG_OBJECT = COMPONENT | FRONT_END | EVALUATOR
|
|
58
|
+
CONFIG_OBJECT = COMPONENT | FRONT_END | EVALUATOR | AUTHENTICATION
|
|
55
59
|
"""
|
|
56
60
|
Any plugin that can be specified in the AIQ Toolkit configuration file.
|
|
57
61
|
"""
|
|
58
|
-
ALL = COMPONENT | FRONT_END | EVALUATOR | REGISTRY_HANDLER
|
|
62
|
+
ALL = COMPONENT | FRONT_END | EVALUATOR | REGISTRY_HANDLER | AUTHENTICATION
|
|
59
63
|
"""
|
|
60
64
|
All plugin types
|
|
61
65
|
"""
|
|
@@ -130,6 +134,8 @@ def discover_entrypoints(plugin_type: PluginTypes):
|
|
|
130
134
|
plugin_groups.append("aiq.registry_handlers")
|
|
131
135
|
if (plugin_type & PluginTypes.EVALUATOR):
|
|
132
136
|
plugin_groups.append("aiq.evaluators")
|
|
137
|
+
if (plugin_type & PluginTypes.AUTHENTICATION):
|
|
138
|
+
plugin_groups.append("aiq.authentication_providers")
|
|
133
139
|
|
|
134
140
|
# Get the entry points for the specified groups
|
|
135
141
|
aiq_plugins = reduce(lambda x, y: list(x) + list(y), [entry_points.select(group=y) for y in plugin_groups])
|
aiq/runtime/runner.py
CHANGED
|
@@ -21,7 +21,7 @@ from aiq.builder.context import AIQContext
|
|
|
21
21
|
from aiq.builder.context import AIQContextState
|
|
22
22
|
from aiq.builder.function import Function
|
|
23
23
|
from aiq.data_models.invocation_node import InvocationNode
|
|
24
|
-
from aiq.observability.
|
|
24
|
+
from aiq.observability.exporter_manager import ExporterManager
|
|
25
25
|
from aiq.utils.reactive.subject import Subject
|
|
26
26
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
@@ -44,7 +44,11 @@ _T = typing.TypeVar("_T")
|
|
|
44
44
|
|
|
45
45
|
class AIQRunner:
|
|
46
46
|
|
|
47
|
-
def __init__(self,
|
|
47
|
+
def __init__(self,
|
|
48
|
+
input_message: typing.Any,
|
|
49
|
+
entry_fn: Function,
|
|
50
|
+
context_state: AIQContextState,
|
|
51
|
+
exporter_manager: ExporterManager):
|
|
48
52
|
"""
|
|
49
53
|
The AIQRunner class is used to run a workflow. It handles converting input and output data types and running the
|
|
50
54
|
workflow with the specified concurrency.
|
|
@@ -57,6 +61,8 @@ class AIQRunner:
|
|
|
57
61
|
The entry function to the workflow
|
|
58
62
|
context_state : AIQContextState
|
|
59
63
|
The context state to use
|
|
64
|
+
exporter_manager : ExporterManager
|
|
65
|
+
The exporter manager to use
|
|
60
66
|
"""
|
|
61
67
|
|
|
62
68
|
if (entry_fn is None):
|
|
@@ -73,7 +79,7 @@ class AIQRunner:
|
|
|
73
79
|
# Before we start, we need to convert the input message to the workflow input type
|
|
74
80
|
self._input_message = input_message
|
|
75
81
|
|
|
76
|
-
self.
|
|
82
|
+
self._exporter_manager = exporter_manager
|
|
77
83
|
|
|
78
84
|
@property
|
|
79
85
|
def context(self) -> AIQContext:
|
|
@@ -130,19 +136,23 @@ class AIQRunner:
|
|
|
130
136
|
if (not self._entry_fn.has_single_output):
|
|
131
137
|
raise ValueError("Workflow does not support single output")
|
|
132
138
|
|
|
133
|
-
async with self.
|
|
139
|
+
async with self._exporter_manager.start(context_state=self._context_state):
|
|
134
140
|
# Run the workflow
|
|
135
141
|
result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)
|
|
136
142
|
|
|
137
143
|
# Close the intermediate stream
|
|
138
|
-
self._context_state.event_stream.get().on_complete()
|
|
144
|
+
event_stream = self._context_state.event_stream.get()
|
|
145
|
+
if event_stream:
|
|
146
|
+
event_stream.on_complete()
|
|
139
147
|
|
|
140
148
|
self._state = AIQRunnerState.COMPLETED
|
|
141
149
|
|
|
142
150
|
return result
|
|
143
151
|
except Exception as e:
|
|
144
152
|
logger.exception("Error running workflow: %s", e)
|
|
145
|
-
self._context_state.event_stream.get().on_complete()
|
|
153
|
+
event_stream = self._context_state.event_stream.get()
|
|
154
|
+
if event_stream:
|
|
155
|
+
event_stream.on_complete()
|
|
146
156
|
self._state = AIQRunnerState.FAILED
|
|
147
157
|
|
|
148
158
|
raise
|
|
@@ -159,18 +169,22 @@ class AIQRunner:
|
|
|
159
169
|
raise ValueError("Workflow does not support streaming output")
|
|
160
170
|
|
|
161
171
|
# Run the workflow
|
|
162
|
-
async with self.
|
|
172
|
+
async with self._exporter_manager.start(context_state=self._context_state):
|
|
163
173
|
async for m in self._entry_fn.astream(self._input_message, to_type=to_type):
|
|
164
174
|
yield m
|
|
165
175
|
|
|
166
176
|
self._state = AIQRunnerState.COMPLETED
|
|
167
177
|
|
|
168
178
|
# Close the intermediate stream
|
|
169
|
-
self._context_state.event_stream.get().on_complete()
|
|
179
|
+
event_stream = self._context_state.event_stream.get()
|
|
180
|
+
if event_stream:
|
|
181
|
+
event_stream.on_complete()
|
|
170
182
|
|
|
171
183
|
except Exception as e:
|
|
172
184
|
logger.exception("Error running workflow: %s", e)
|
|
173
|
-
self._context_state.event_stream.get().on_complete()
|
|
185
|
+
event_stream = self._context_state.event_stream.get()
|
|
186
|
+
if event_stream:
|
|
187
|
+
event_stream.on_complete()
|
|
174
188
|
self._state = AIQRunnerState.FAILED
|
|
175
189
|
|
|
176
190
|
raise
|
aiq/runtime/session.py
CHANGED
|
@@ -21,11 +21,14 @@ from collections.abc import Callable
|
|
|
21
21
|
from contextlib import asynccontextmanager
|
|
22
22
|
from contextlib import nullcontext
|
|
23
23
|
|
|
24
|
-
from
|
|
24
|
+
from starlette.requests import HTTPConnection
|
|
25
25
|
|
|
26
26
|
from aiq.builder.context import AIQContext
|
|
27
27
|
from aiq.builder.context import AIQContextState
|
|
28
28
|
from aiq.builder.workflow import Workflow
|
|
29
|
+
from aiq.data_models.authentication import AuthenticatedContext
|
|
30
|
+
from aiq.data_models.authentication import AuthFlowType
|
|
31
|
+
from aiq.data_models.authentication import AuthProviderBaseConfig
|
|
29
32
|
from aiq.data_models.config import AIQConfig
|
|
30
33
|
from aiq.data_models.interactive import HumanResponse
|
|
31
34
|
from aiq.data_models.interactive import InteractionPrompt
|
|
@@ -86,9 +89,11 @@ class AIQSessionManager:
|
|
|
86
89
|
@asynccontextmanager
|
|
87
90
|
async def session(self,
|
|
88
91
|
user_manager=None,
|
|
89
|
-
request:
|
|
92
|
+
request: HTTPConnection | None = None,
|
|
90
93
|
conversation_id: str | None = None,
|
|
91
|
-
user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None
|
|
94
|
+
user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None,
|
|
95
|
+
user_authentication_callback: Callable[[AuthProviderBaseConfig, AuthFlowType],
|
|
96
|
+
Awaitable[AuthenticatedContext | None]] = None):
|
|
92
97
|
|
|
93
98
|
token_user_input = None
|
|
94
99
|
if user_input_callback is not None:
|
|
@@ -98,6 +103,10 @@ class AIQSessionManager:
|
|
|
98
103
|
if user_manager is not None:
|
|
99
104
|
token_user_manager = self._context_state.user_manager.set(user_manager)
|
|
100
105
|
|
|
106
|
+
token_user_authentication = None
|
|
107
|
+
if user_authentication_callback is not None:
|
|
108
|
+
token_user_authentication = self._context_state.user_auth_callback.set(user_authentication_callback)
|
|
109
|
+
|
|
101
110
|
if conversation_id is not None and request is None:
|
|
102
111
|
self._context_state.conversation_id.set(conversation_id)
|
|
103
112
|
|
|
@@ -110,6 +119,8 @@ class AIQSessionManager:
|
|
|
110
119
|
self._context_state.user_manager.reset(token_user_manager)
|
|
111
120
|
if token_user_input is not None:
|
|
112
121
|
self._context_state.user_input_callback.reset(token_user_input)
|
|
122
|
+
if token_user_authentication is not None:
|
|
123
|
+
self._context_state.user_auth_callback.reset(token_user_authentication)
|
|
113
124
|
|
|
114
125
|
@asynccontextmanager
|
|
115
126
|
async def run(self, message):
|
|
@@ -124,7 +135,7 @@ class AIQSessionManager:
|
|
|
124
135
|
async with self._workflow.run(message) as runner:
|
|
125
136
|
yield runner
|
|
126
137
|
|
|
127
|
-
def set_metadata_from_http_request(self, request:
|
|
138
|
+
def set_metadata_from_http_request(self, request: HTTPConnection | None) -> None:
|
|
128
139
|
"""
|
|
129
140
|
Extracts and sets user metadata request attributes from a HTTP request.
|
|
130
141
|
If request is None, no attributes are set.
|
|
@@ -132,7 +143,7 @@ class AIQSessionManager:
|
|
|
132
143
|
if request is None:
|
|
133
144
|
return
|
|
134
145
|
|
|
135
|
-
self._context.metadata._request.method = request.method
|
|
146
|
+
self._context.metadata._request.method = getattr(request, "method", None)
|
|
136
147
|
self._context.metadata._request.url_path = request.url.path
|
|
137
148
|
self._context.metadata._request.url_port = request.url.port
|
|
138
149
|
self._context.metadata._request.url_scheme = request.url.scheme
|