aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/base.py +170 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +146 -112
- aiq/agent/react_agent/prompt.py +1 -6
- aiq/agent/react_agent/register.py +36 -35
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/agent/tool_calling_agent/register.py +1 -1
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +37 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +44 -12
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +421 -61
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +2 -1
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +124 -12
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/dataset_handler.py +2 -1
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/evaluate.py +23 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +38 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +1 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +87 -2
- aiq/eval/evaluate.py +208 -27
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +1 -0
- aiq/eval/intermediate_step_adapter.py +11 -5
- aiq/eval/rag_evaluator/evaluate.py +55 -15
- aiq/eval/rag_evaluator/register.py +6 -1
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/trajectory_evaluator/evaluate.py +22 -65
- aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
- aiq/eval/tunable_rag_evaluator/register.py +2 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/output_uploader.py +10 -1
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/job_store.py +47 -25
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +108 -89
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +3 -2
- aiq/llm/register.py +1 -0
- aiq/meta/pypi.md +12 -12
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +36 -39
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/inference_metrics_model.py +3 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
- aiq/profiler/inference_optimization/data_models.py +2 -2
- aiq/profiler/inference_optimization/llm_metrics.py +2 -2
- aiq/profiler/profile_runner.py +61 -21
- aiq/runtime/loader.py +9 -3
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +25 -7
- aiq/runtime/user_metadata.py +2 -3
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +41 -6
- aiq/tool/mcp/mcp_tool.py +3 -2
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +6 -3
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/METADATA +53 -21
- aiqtoolkit-1.2.0rc1.dist-info/RECORD +436 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/WHEEL +1 -1
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -148
- aiq/observability/async_otel_listener.py +0 -429
- aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -16,11 +16,13 @@
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import logging
|
|
18
18
|
import os
|
|
19
|
+
import time
|
|
19
20
|
import typing
|
|
20
21
|
from abc import ABC
|
|
21
22
|
from abc import abstractmethod
|
|
23
|
+
from collections.abc import Awaitable
|
|
24
|
+
from collections.abc import Callable
|
|
22
25
|
from contextlib import asynccontextmanager
|
|
23
|
-
from functools import partial
|
|
24
26
|
from pathlib import Path
|
|
25
27
|
|
|
26
28
|
from fastapi import BackgroundTasks
|
|
@@ -28,10 +30,13 @@ from fastapi import Body
|
|
|
28
30
|
from fastapi import FastAPI
|
|
29
31
|
from fastapi import Request
|
|
30
32
|
from fastapi import Response
|
|
33
|
+
from fastapi import UploadFile
|
|
31
34
|
from fastapi.exceptions import HTTPException
|
|
32
35
|
from fastapi.middleware.cors import CORSMiddleware
|
|
33
36
|
from fastapi.responses import StreamingResponse
|
|
34
37
|
from pydantic import BaseModel
|
|
38
|
+
from pydantic import Field
|
|
39
|
+
from starlette.websockets import WebSocket
|
|
35
40
|
|
|
36
41
|
from aiq.builder.workflow_builder import WorkflowBuilder
|
|
37
42
|
from aiq.data_models.api_server import AIQChatRequest
|
|
@@ -39,20 +44,28 @@ from aiq.data_models.api_server import AIQChatResponse
|
|
|
39
44
|
from aiq.data_models.api_server import AIQChatResponseChunk
|
|
40
45
|
from aiq.data_models.api_server import AIQResponseIntermediateStep
|
|
41
46
|
from aiq.data_models.config import AIQConfig
|
|
47
|
+
from aiq.data_models.object_store import KeyAlreadyExistsError
|
|
48
|
+
from aiq.data_models.object_store import NoSuchKeyError
|
|
42
49
|
from aiq.eval.config import EvaluationRunOutput
|
|
43
50
|
from aiq.eval.evaluate import EvaluationRun
|
|
44
51
|
from aiq.eval.evaluate import EvaluationRunConfig
|
|
52
|
+
from aiq.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
|
|
53
|
+
from aiq.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
|
|
54
|
+
from aiq.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
|
|
55
|
+
from aiq.front_ends.fastapi.fastapi_front_end_config import AIQAsyncGenerateResponse
|
|
56
|
+
from aiq.front_ends.fastapi.fastapi_front_end_config import AIQAsyncGenerationStatusResponse
|
|
45
57
|
from aiq.front_ends.fastapi.fastapi_front_end_config import AIQEvaluateRequest
|
|
46
58
|
from aiq.front_ends.fastapi.fastapi_front_end_config import AIQEvaluateResponse
|
|
47
59
|
from aiq.front_ends.fastapi.fastapi_front_end_config import AIQEvaluateStatusResponse
|
|
48
60
|
from aiq.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
49
61
|
from aiq.front_ends.fastapi.job_store import JobInfo
|
|
50
62
|
from aiq.front_ends.fastapi.job_store import JobStore
|
|
63
|
+
from aiq.front_ends.fastapi.message_handler import WebSocketMessageHandler
|
|
51
64
|
from aiq.front_ends.fastapi.response_helpers import generate_single_response
|
|
52
65
|
from aiq.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
|
|
53
66
|
from aiq.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
|
|
54
67
|
from aiq.front_ends.fastapi.step_adaptor import StepAdaptor
|
|
55
|
-
from aiq.
|
|
68
|
+
from aiq.object_store.models import ObjectStoreItem
|
|
56
69
|
from aiq.runtime.session import AIQSessionManager
|
|
57
70
|
|
|
58
71
|
logger = logging.getLogger(__name__)
|
|
@@ -68,13 +81,16 @@ class FastApiFrontEndPluginWorkerBase(ABC):
|
|
|
68
81
|
|
|
69
82
|
self._front_end_config = config.general.front_end
|
|
70
83
|
|
|
84
|
+
self._cleanup_tasks: list[str] = []
|
|
85
|
+
self._cleanup_tasks_lock = asyncio.Lock()
|
|
86
|
+
self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
|
|
87
|
+
|
|
71
88
|
@property
|
|
72
89
|
def config(self) -> AIQConfig:
|
|
73
90
|
return self._config
|
|
74
91
|
|
|
75
92
|
@property
|
|
76
93
|
def front_end_config(self) -> FastApiFrontEndConfig:
|
|
77
|
-
|
|
78
94
|
return self._front_end_config
|
|
79
95
|
|
|
80
96
|
def build_app(self) -> FastAPI:
|
|
@@ -92,17 +108,30 @@ class FastApiFrontEndPluginWorkerBase(ABC):
|
|
|
92
108
|
yield
|
|
93
109
|
|
|
94
110
|
# If a cleanup task is running, cancel it
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
111
|
+
async with self._cleanup_tasks_lock:
|
|
112
|
+
|
|
113
|
+
# Cancel all cleanup tasks
|
|
114
|
+
for task_name in self._cleanup_tasks:
|
|
115
|
+
cleanup_task: asyncio.Task | None = getattr(starting_app.state, task_name, None)
|
|
116
|
+
if cleanup_task is not None:
|
|
117
|
+
logger.info("Cancelling %s cleanup task", task_name)
|
|
118
|
+
cleanup_task.cancel()
|
|
119
|
+
else:
|
|
120
|
+
logger.warning("No cleanup task found for %s", task_name)
|
|
121
|
+
|
|
122
|
+
self._cleanup_tasks.clear()
|
|
99
123
|
|
|
100
124
|
logger.debug("Closing AIQ Toolkit server from process %s", os.getpid())
|
|
101
125
|
|
|
102
126
|
aiq_app = FastAPI(lifespan=lifespan)
|
|
103
127
|
|
|
128
|
+
# Configure app CORS.
|
|
104
129
|
self.set_cors_config(aiq_app)
|
|
105
130
|
|
|
131
|
+
@aiq_app.middleware("http")
|
|
132
|
+
async def authentication_log_filter(request: Request, call_next: Callable[[Request], Awaitable[Response]]):
|
|
133
|
+
return await self._suppress_authentication_logs(request, call_next)
|
|
134
|
+
|
|
106
135
|
return aiq_app
|
|
107
136
|
|
|
108
137
|
def set_cors_config(self, aiq_app: FastAPI) -> None:
|
|
@@ -137,6 +166,26 @@ class FastApiFrontEndPluginWorkerBase(ABC):
|
|
|
137
166
|
**cors_kwargs,
|
|
138
167
|
)
|
|
139
168
|
|
|
169
|
+
async def _suppress_authentication_logs(self, request: Request,
|
|
170
|
+
call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
|
171
|
+
"""
|
|
172
|
+
Intercepts authentication request and supreses logs that contain sensitive data.
|
|
173
|
+
"""
|
|
174
|
+
from aiq.utils.log_utils import LogFilter
|
|
175
|
+
|
|
176
|
+
logs_to_suppress: list[str] = []
|
|
177
|
+
|
|
178
|
+
if (self.front_end_config.oauth2_callback_path):
|
|
179
|
+
logs_to_suppress.append(self.front_end_config.oauth2_callback_path)
|
|
180
|
+
|
|
181
|
+
logging.getLogger("uvicorn.access").addFilter(LogFilter(logs_to_suppress))
|
|
182
|
+
try:
|
|
183
|
+
response = await call_next(request)
|
|
184
|
+
finally:
|
|
185
|
+
logging.getLogger("uvicorn.access").removeFilter(LogFilter(logs_to_suppress))
|
|
186
|
+
|
|
187
|
+
return response
|
|
188
|
+
|
|
140
189
|
@abstractmethod
|
|
141
190
|
async def configure(self, app: FastAPI, builder: WorkflowBuilder):
|
|
142
191
|
pass
|
|
@@ -153,6 +202,38 @@ class RouteInfo(BaseModel):
|
|
|
153
202
|
|
|
154
203
|
class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
155
204
|
|
|
205
|
+
def __init__(self, config: AIQConfig):
|
|
206
|
+
super().__init__(config)
|
|
207
|
+
|
|
208
|
+
self._outstanding_flows: dict[str, FlowState] = {}
|
|
209
|
+
self._outstanding_flows_lock = asyncio.Lock()
|
|
210
|
+
|
|
211
|
+
@staticmethod
|
|
212
|
+
async def _periodic_cleanup(name: str, job_store: JobStore, sleep_time_sec: int = 300):
|
|
213
|
+
while True:
|
|
214
|
+
try:
|
|
215
|
+
job_store.cleanup_expired_jobs()
|
|
216
|
+
logger.debug("Expired %s jobs cleaned up", name)
|
|
217
|
+
except Exception as e:
|
|
218
|
+
logger.error("Error during %s job cleanup: %s", name, e)
|
|
219
|
+
await asyncio.sleep(sleep_time_sec)
|
|
220
|
+
|
|
221
|
+
async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
|
|
222
|
+
# Schedule periodic cleanup of expired jobs on first job creation
|
|
223
|
+
attr_name = f"{name}_cleanup_task"
|
|
224
|
+
|
|
225
|
+
# Cheap check, if it doesn't exist, we will need to re-check after we acquire the lock
|
|
226
|
+
if not hasattr(app.state, attr_name):
|
|
227
|
+
async with self._cleanup_tasks_lock:
|
|
228
|
+
if not hasattr(app.state, attr_name):
|
|
229
|
+
logger.info("Starting %s periodic cleanup task", name)
|
|
230
|
+
setattr(
|
|
231
|
+
app.state,
|
|
232
|
+
attr_name,
|
|
233
|
+
asyncio.create_task(
|
|
234
|
+
self._periodic_cleanup(name=name, job_store=job_store, sleep_time_sec=sleep_time_sec)))
|
|
235
|
+
self._cleanup_tasks.append(attr_name)
|
|
236
|
+
|
|
156
237
|
def get_step_adaptor(self) -> StepAdaptor:
|
|
157
238
|
|
|
158
239
|
return StepAdaptor(self.front_end_config.step_adaptor)
|
|
@@ -168,6 +249,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
168
249
|
|
|
169
250
|
await self.add_default_route(app, AIQSessionManager(builder.build()))
|
|
170
251
|
await self.add_evaluate_route(app, AIQSessionManager(builder.build()))
|
|
252
|
+
await self.add_static_files_route(app, builder)
|
|
253
|
+
await self.add_authorization_route(app)
|
|
171
254
|
|
|
172
255
|
for ep in self.front_end_config.endpoints:
|
|
173
256
|
|
|
@@ -198,21 +281,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
198
281
|
# Don't run multiple evaluations at the same time
|
|
199
282
|
evaluation_lock = asyncio.Lock()
|
|
200
283
|
|
|
201
|
-
async def periodic_cleanup(job_store: JobStore):
|
|
202
|
-
while True:
|
|
203
|
-
try:
|
|
204
|
-
job_store.cleanup_expired_jobs()
|
|
205
|
-
logger.debug("Expired jobs cleaned up")
|
|
206
|
-
except Exception as e:
|
|
207
|
-
logger.error("Error during job cleanup: %s", str(e))
|
|
208
|
-
await asyncio.sleep(300) # every 5 minutes
|
|
209
|
-
|
|
210
|
-
def create_cleanup_task():
|
|
211
|
-
# Schedule periodic cleanup of expired jobs on first job creation
|
|
212
|
-
if not hasattr(app.state, "cleanup_task"):
|
|
213
|
-
logger.info("Starting periodic cleanup task")
|
|
214
|
-
app.state.cleanup_task = asyncio.create_task(periodic_cleanup(job_store))
|
|
215
|
-
|
|
216
284
|
async def run_evaluation(job_id: str, config_file: str, reps: int, session_manager: AIQSessionManager):
|
|
217
285
|
"""Background task to run the evaluation."""
|
|
218
286
|
async with evaluation_lock:
|
|
@@ -250,7 +318,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
250
318
|
return AIQEvaluateResponse(job_id=job.job_id, status=job.status)
|
|
251
319
|
|
|
252
320
|
job_id = job_store.create_job(request.config_file, request.job_id, request.expiry_seconds)
|
|
253
|
-
create_cleanup_task()
|
|
321
|
+
await self.create_cleanup_task(app=app, name="async_evaluation", job_store=job_store)
|
|
254
322
|
background_tasks.add_task(run_evaluation, job_id, request.config_file, request.reps, session_manager)
|
|
255
323
|
|
|
256
324
|
return AIQEvaluateResponse(job_id=job_id, status="submitted")
|
|
@@ -276,7 +344,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
276
344
|
if not job:
|
|
277
345
|
logger.warning("Job %s not found", job_id)
|
|
278
346
|
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
|
279
|
-
logger.info(
|
|
347
|
+
logger.info("Found job %s with status %s", job_id, job.status)
|
|
280
348
|
return translate_job_to_response(job)
|
|
281
349
|
|
|
282
350
|
async def get_last_job_status(http_request: Request) -> AIQEvaluateStatusResponse:
|
|
@@ -355,6 +423,100 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
355
423
|
responses={500: response_500},
|
|
356
424
|
)
|
|
357
425
|
|
|
426
|
+
async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
|
|
427
|
+
|
|
428
|
+
if not self.front_end_config.object_store:
|
|
429
|
+
logger.debug("No object store configured, skipping static files route")
|
|
430
|
+
return
|
|
431
|
+
|
|
432
|
+
object_store_client = await builder.get_object_store_client(self.front_end_config.object_store)
|
|
433
|
+
|
|
434
|
+
def sanitize_path(path: str) -> str:
|
|
435
|
+
sanitized_path = os.path.normpath(path.strip("/"))
|
|
436
|
+
if sanitized_path == ".":
|
|
437
|
+
raise HTTPException(status_code=400, detail="Invalid file path.")
|
|
438
|
+
filename = os.path.basename(sanitized_path)
|
|
439
|
+
if not filename:
|
|
440
|
+
raise HTTPException(status_code=400, detail="Filename cannot be empty.")
|
|
441
|
+
return sanitized_path
|
|
442
|
+
|
|
443
|
+
# Upload static files to the object store; if key is present, it will fail with 409 Conflict
|
|
444
|
+
async def add_static_file(file_path: str, file: UploadFile):
|
|
445
|
+
sanitized_file_path = sanitize_path(file_path)
|
|
446
|
+
file_data = await file.read()
|
|
447
|
+
|
|
448
|
+
try:
|
|
449
|
+
await object_store_client.put_object(sanitized_file_path,
|
|
450
|
+
ObjectStoreItem(data=file_data, content_type=file.content_type))
|
|
451
|
+
except KeyAlreadyExistsError as e:
|
|
452
|
+
raise HTTPException(status_code=409, detail=str(e)) from e
|
|
453
|
+
|
|
454
|
+
return {"filename": sanitized_file_path}
|
|
455
|
+
|
|
456
|
+
# Upsert static files to the object store; if key is present, it will overwrite the file
|
|
457
|
+
async def upsert_static_file(file_path: str, file: UploadFile):
|
|
458
|
+
sanitized_file_path = sanitize_path(file_path)
|
|
459
|
+
file_data = await file.read()
|
|
460
|
+
|
|
461
|
+
await object_store_client.upsert_object(sanitized_file_path,
|
|
462
|
+
ObjectStoreItem(data=file_data, content_type=file.content_type))
|
|
463
|
+
|
|
464
|
+
return {"filename": sanitized_file_path}
|
|
465
|
+
|
|
466
|
+
# Get static files from the object store
|
|
467
|
+
async def get_static_file(file_path: str):
|
|
468
|
+
|
|
469
|
+
try:
|
|
470
|
+
file_data = await object_store_client.get_object(file_path)
|
|
471
|
+
except NoSuchKeyError as e:
|
|
472
|
+
raise HTTPException(status_code=404, detail=str(e)) from e
|
|
473
|
+
|
|
474
|
+
filename = file_path.split("/")[-1]
|
|
475
|
+
|
|
476
|
+
async def reader():
|
|
477
|
+
yield file_data.data
|
|
478
|
+
|
|
479
|
+
return StreamingResponse(reader(),
|
|
480
|
+
media_type=file_data.content_type,
|
|
481
|
+
headers={"Content-Disposition": f"attachment; filename={filename}"})
|
|
482
|
+
|
|
483
|
+
async def delete_static_file(file_path: str):
|
|
484
|
+
try:
|
|
485
|
+
await object_store_client.delete_object(file_path)
|
|
486
|
+
except NoSuchKeyError as e:
|
|
487
|
+
raise HTTPException(status_code=404, detail=str(e)) from e
|
|
488
|
+
|
|
489
|
+
return Response(status_code=204)
|
|
490
|
+
|
|
491
|
+
# Add the static files route to the FastAPI app
|
|
492
|
+
app.add_api_route(
|
|
493
|
+
path="/static/{file_path:path}",
|
|
494
|
+
endpoint=add_static_file,
|
|
495
|
+
methods=["POST"],
|
|
496
|
+
description="Upload a static file to the object store",
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
app.add_api_route(
|
|
500
|
+
path="/static/{file_path:path}",
|
|
501
|
+
endpoint=upsert_static_file,
|
|
502
|
+
methods=["PUT"],
|
|
503
|
+
description="Upsert a static file to the object store",
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
app.add_api_route(
|
|
507
|
+
path="/static/{file_path:path}",
|
|
508
|
+
endpoint=get_static_file,
|
|
509
|
+
methods=["GET"],
|
|
510
|
+
description="Get a static file from the object store",
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
app.add_api_route(
|
|
514
|
+
path="/static/{file_path:path}",
|
|
515
|
+
endpoint=delete_static_file,
|
|
516
|
+
methods=["DELETE"],
|
|
517
|
+
description="Delete a static file from the object store",
|
|
518
|
+
)
|
|
519
|
+
|
|
358
520
|
async def add_route(self,
|
|
359
521
|
app: FastAPI,
|
|
360
522
|
endpoint: FastApiFrontEndConfig.EndpointBase,
|
|
@@ -362,17 +524,32 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
362
524
|
|
|
363
525
|
workflow = session_manager.workflow
|
|
364
526
|
|
|
365
|
-
if (endpoint.websocket_path):
|
|
366
|
-
app.add_websocket_route(endpoint.websocket_path,
|
|
367
|
-
partial(AIQWebSocket, session_manager, self.get_step_adaptor()))
|
|
368
|
-
|
|
369
527
|
GenerateBodyType = workflow.input_schema # pylint: disable=invalid-name
|
|
370
528
|
GenerateStreamResponseType = workflow.streaming_output_schema # pylint: disable=invalid-name
|
|
371
529
|
GenerateSingleResponseType = workflow.single_output_schema # pylint: disable=invalid-name
|
|
372
530
|
|
|
531
|
+
# Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
|
|
532
|
+
# Consider prefixing these with "aiq_" to avoid conflicts
|
|
533
|
+
class AIQAsyncGenerateRequest(GenerateBodyType):
|
|
534
|
+
job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
|
|
535
|
+
sync_timeout: int = Field(
|
|
536
|
+
default=0,
|
|
537
|
+
ge=0,
|
|
538
|
+
le=300,
|
|
539
|
+
description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
|
|
540
|
+
"if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
|
|
541
|
+
expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
|
|
542
|
+
ge=JobStore.MIN_EXPIRY,
|
|
543
|
+
le=JobStore.MAX_EXPIRY,
|
|
544
|
+
description="Optional time (in seconds) before the job expires. "
|
|
545
|
+
"Clamped between 600 (10 min) and 86400 (24h).")
|
|
546
|
+
|
|
373
547
|
# Ensure that the input is in the body. POD types are treated as query parameters
|
|
374
548
|
if (not issubclass(GenerateBodyType, BaseModel)):
|
|
375
549
|
GenerateBodyType = typing.Annotated[GenerateBodyType, Body()]
|
|
550
|
+
else:
|
|
551
|
+
logger.info("Expecting generate request payloads in the following format: %s",
|
|
552
|
+
GenerateBodyType.model_fields)
|
|
376
553
|
|
|
377
554
|
response_500 = {
|
|
378
555
|
"description": "Internal Server Error",
|
|
@@ -385,13 +562,20 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
385
562
|
},
|
|
386
563
|
}
|
|
387
564
|
|
|
565
|
+
# Create job store for tracking async generation jobs
|
|
566
|
+
job_store = JobStore()
|
|
567
|
+
|
|
568
|
+
# Run up to max_running_async_jobs jobs at the same time
|
|
569
|
+
async_job_concurrency = asyncio.Semaphore(self._front_end_config.max_running_async_jobs)
|
|
570
|
+
|
|
388
571
|
def get_single_endpoint(result_type: type | None):
|
|
389
572
|
|
|
390
573
|
async def get_single(response: Response, request: Request):
|
|
391
574
|
|
|
392
575
|
response.headers["Content-Type"] = "application/json"
|
|
393
576
|
|
|
394
|
-
async with session_manager.session(request=request
|
|
577
|
+
async with session_manager.session(request=request,
|
|
578
|
+
user_authentication_callback=self._http_flow_handler.authenticate):
|
|
395
579
|
|
|
396
580
|
return await generate_single_response(None, session_manager, result_type=result_type)
|
|
397
581
|
|
|
@@ -401,7 +585,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
401
585
|
|
|
402
586
|
async def get_stream(request: Request):
|
|
403
587
|
|
|
404
|
-
async with session_manager.session(request=request
|
|
588
|
+
async with session_manager.session(request=request,
|
|
589
|
+
user_authentication_callback=self._http_flow_handler.authenticate):
|
|
405
590
|
|
|
406
591
|
return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
|
|
407
592
|
content=generate_streaming_response_as_str(
|
|
@@ -435,7 +620,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
435
620
|
|
|
436
621
|
response.headers["Content-Type"] = "application/json"
|
|
437
622
|
|
|
438
|
-
async with session_manager.session(request=request
|
|
623
|
+
async with session_manager.session(request=request,
|
|
624
|
+
user_authentication_callback=self._http_flow_handler.authenticate):
|
|
439
625
|
|
|
440
626
|
return await generate_single_response(payload, session_manager, result_type=result_type)
|
|
441
627
|
|
|
@@ -448,7 +634,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
448
634
|
|
|
449
635
|
async def post_stream(request: Request, payload: request_type):
|
|
450
636
|
|
|
451
|
-
async with session_manager.session(request=request
|
|
637
|
+
async with session_manager.session(request=request,
|
|
638
|
+
user_authentication_callback=self._http_flow_handler.authenticate):
|
|
452
639
|
|
|
453
640
|
return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
|
|
454
641
|
content=generate_streaming_response_as_str(
|
|
@@ -482,7 +669,206 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
482
669
|
|
|
483
670
|
return post_stream
|
|
484
671
|
|
|
672
|
+
def post_openai_api_compatible_endpoint(request_type: type):
    """
    OpenAI-compatible endpoint that handles both streaming and non-streaming
    based on the 'stream' parameter in the request.

    Args:
        request_type: Model used to parse the request body. When the parsed payload has a truthy
            ``stream`` attribute the response is streamed as server-sent events.

    Returns:
        The ``post_openai_api_compatible`` coroutine function to register as a route.
    """

    async def post_openai_api_compatible(response: Response, request: Request, payload: request_type):
        # Check if streaming is requested
        stream_requested = getattr(payload, 'stream', False)

        # Pass the HTTP authentication callback exactly like the other single/streaming endpoints do,
        # so workflows that require interactive user authentication behave the same on this route.
        async with session_manager.session(request=request,
                                           user_authentication_callback=self._http_flow_handler.authenticate):
            if stream_requested:
                # Return streaming response
                return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
                                         content=generate_streaming_response_as_str(
                                             payload,
                                             session_manager=session_manager,
                                             streaming=True,
                                             step_adaptor=self.get_step_adaptor(),
                                             result_type=AIQChatResponseChunk,
                                             output_type=AIQChatResponseChunk))

            # Return single response - check if workflow supports non-streaming
            try:
                response.headers["Content-Type"] = "application/json"
                return await generate_single_response(payload, session_manager, result_type=AIQChatResponse)
            except ValueError as e:
                if "Cannot get a single output value for streaming workflows" not in str(e):
                    raise

            # Workflow only supports streaming, but client requested non-streaming:
            # run it in streaming mode and concatenate the content deltas into one response.
            # (Done outside the except block so later errors are not chained to the ValueError.)
            chunks: list[str] = []
            async for chunk_str in generate_streaming_response_as_str(
                    payload,
                    session_manager=session_manager,
                    streaming=True,
                    step_adaptor=self.get_step_adaptor(),
                    result_type=AIQChatResponseChunk,
                    output_type=AIQChatResponseChunk):
                # Only "data: " events carry chat chunks; "[DONE]" is the end-of-stream sentinel.
                if not chunk_str.startswith("data: ") or chunk_str.startswith("data: [DONE]"):
                    continue
                chunk_data = chunk_str.removeprefix("data: ").strip()
                if not chunk_data:
                    continue
                try:
                    chunk_json = AIQChatResponseChunk.model_validate_json(chunk_data)
                except Exception:
                    # Best effort: skip events that are not chat-completion chunks, but leave a trace.
                    logger.debug("Skipping unparsable stream chunk: %s", chunk_data)
                    continue
                if (chunk_json.choices and chunk_json.choices[0].delta
                        and chunk_json.choices[0].delta.content is not None):
                    chunks.append(chunk_json.choices[0].delta.content)

            # Create a single response from collected chunks
            content = "".join(chunks)
            single_response = AIQChatResponse.from_string(content)
            response.headers["Content-Type"] = "application/json"
            return single_response

    return post_openai_api_compatible
|
|
731
|
+
|
|
732
|
+
async def run_generation(job_id: str,
                         payload: typing.Any,
                         session_manager: AIQSessionManager,
                         result_type: type):
    """Background task that runs one async generation job and records its outcome.

    Waits for a slot in ``async_job_concurrency`` (which bounds how many async jobs run at the same
    time), runs the workflow through ``generate_single_response`` and stores either the result
    (status ``"success"``) or the error message (status ``"failure"``) in ``job_store``.

    Args:
        job_id: Identifier of the job in ``job_store``.
        payload: Request payload passed to the workflow.
        session_manager: Session manager used to run the workflow.
        result_type: Type the single workflow result is converted to.
    """
    async with async_job_concurrency:
        try:
            result = await generate_single_response(payload=payload,
                                                    session_manager=session_manager,
                                                    result_type=result_type)
            job_store.update_status(job_id, "success", output=result)
        except Exception as e:
            # Top-level boundary of a detached task: nothing else will see this exception, so log the
            # full traceback and record the failure for the status endpoint.
            logger.exception("Error in async generation job %s: %s", job_id, e)
            job_store.update_status(job_id, "failure", error=str(e))
|
|
746
|
+
|
|
747
|
+
def _job_status_to_response(job: JobInfo) -> AIQAsyncGenerationStatusResponse:
    """Build the client-facing status response for a stored job.

    A non-``None`` job output is serialized with ``model_dump()``; the expiry timestamp is taken
    from ``job_store.get_expires_at``.
    """
    raw_output = job.output
    serialized_output = None if raw_output is None else raw_output.model_dump()
    return AIQAsyncGenerationStatusResponse(
        job_id=job.job_id,
        status=job.status,
        error=job.error,
        output=serialized_output,
        created_at=job.created_at,
        updated_at=job.updated_at,
        expires_at=job_store.get_expires_at(job),
    )
|
|
758
|
+
|
|
759
|
+
def post_async_generation(request_type: type, final_result_type: type):
    """Build the handler that starts an asynchronous generation job.

    Args:
        request_type: Model used to parse the request body; the parsed request is read for
            ``job_id``, ``expiry_seconds`` and ``sync_timeout``.
        final_result_type: Result type passed through to the background generation.

    Returns:
        The ``start_async_generation`` coroutine function to register as a route.
    """

    async def start_async_generation(
            request: request_type, background_tasks: BackgroundTasks, response: Response,
            http_request: Request) -> AIQAsyncGenerateResponse | AIQAsyncGenerationStatusResponse:
        """Handle async generation requests.

        If ``request.job_id`` names an existing job, its current status is returned. Otherwise a new
        job is created and started immediately, and the handler waits up to ``request.sync_timeout``
        seconds for it: a finished job is returned inline with HTTP 200, otherwise HTTP 202 is
        returned with the job id so the client can poll the job status endpoint.
        """

        async with session_manager.session(request=http_request):

            # if job_id is present and already exists return the job info
            if request.job_id:
                job = job_store.get_job(request.job_id)
                if job:
                    return AIQAsyncGenerateResponse(job_id=job.job_id, status=job.status)

            job_id = job_store.create_job(job_id=request.job_id, expiry_seconds=request.expiry_seconds)
            await self.create_cleanup_task(app=app, name="async_generation", job_store=job_store)

            # The fastapi/starlette background tasks won't begin executing until after the response is sent
            # to the client, so we need to wrap the task in a function, allowing us to start the task now,
            # and allowing the background task function to await the results.
            task = asyncio.create_task(
                run_generation(job_id=job_id,
                               payload=request,
                               session_manager=session_manager,
                               result_type=final_result_type))

            async def wrapped_task(t: asyncio.Task):
                return await t

            background_tasks.add_task(wrapped_task, task)

            # Wait for the job to finish within the synchronous window instead of busy-polling the job
            # store. asyncio.wait does not cancel the task on timeout, so the job keeps running.
            done, _ = await asyncio.wait({task}, timeout=max(request.sync_timeout, 0))
            if task in done:
                job = job_store.get_job(job_id)
                if job is not None and job.status not in job_store.ACTIVE_STATUS:
                    # If the job is done, return the result
                    response.status_code = 200
                    return _job_status_to_response(job)

            response.status_code = 202
            return AIQAsyncGenerateResponse(job_id=job_id, status="submitted")

    return start_async_generation
|
|
807
|
+
|
|
808
|
+
async def get_async_job_status(job_id: str, http_request: Request) -> AIQAsyncGenerationStatusResponse:
    """Report the current status of an async generation job.

    Raises:
        HTTPException: 404 when ``job_id`` is not present in the job store.
    """
    logger.info("Getting status for job %s", job_id)

    async with session_manager.session(request=http_request):
        found = job_store.get_job(job_id)
        if found:
            logger.info("Found job %s with status %s", job_id, found.status)
            return _job_status_to_response(found)

        logger.warning("Job %s not found", job_id)
        raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
|
821
|
+
|
|
822
|
+
async def websocket_endpoint(websocket: WebSocket):
    """Serve the workflow over a WebSocket connection.

    A client may pass its session id in the ``session`` query parameter. When it does,
    ``aiqtoolkit-session=<id>`` is injected into the scope's ``cookie`` header before the message
    handler runs, unless a cookie with exactly that name is already present in any Cookie header.

    Args:
        websocket: The incoming WebSocket connection.
    """
    # Universal cookie handling: works for both cross-origin and same-origin connections
    session_id = websocket.query_params.get("session")
    if session_id:
        session_cookie_name = "aiqtoolkit-session"
        cookie_header = f"{session_cookie_name}={session_id}"
        headers = list(websocket.scope.get("headers", []))

        def has_session_cookie(raw_cookie: bytes) -> bool:
            # Compare exact cookie names so that e.g. "my-aiqtoolkit-session=..." is not a match.
            # latin-1 decodes any byte sequence, so malformed header bytes cannot raise here.
            return any(
                pair.split("=", 1)[0].strip() == session_cookie_name
                for pair in raw_cookie.decode("latin-1").split(";"))

        cookie_indices = [i for i, (name, _) in enumerate(headers) if name == b"cookie"]

        if any(has_session_cookie(headers[i][1]) for i in cookie_indices):
            logger.info("WebSocket: Session cookie already present in headers (same-origin)")
        elif cookie_indices:
            # Append to existing cookie header (cross-origin case); operate on bytes to avoid a
            # decode/encode round-trip of the existing header value.
            first = cookie_indices[0]
            name, value = headers[first]
            headers[first] = (name, value + b"; " + cookie_header.encode())
            logger.info("WebSocket: Added session cookie to existing cookie header: %s",
                        session_id[:10] + "...")
        else:
            # No Cookie header at all: add one carrying only the session cookie.
            headers.append((b"cookie", cookie_header.encode()))
            logger.info("WebSocket: Added new session cookie header: %s", session_id[:10] + "...")

        # Update the websocket scope with the modified headers
        websocket.scope["headers"] = headers

    async with WebSocketMessageHandler(websocket, session_manager, self.get_step_adaptor()) as handler:

        flow_handler = WebSocketAuthenticationFlowHandler(self._add_flow, self._remove_flow, handler)

        # The message handler and the flow handler each need a reference to the other, so the flow
        # handler is attached to the message handler after both have been created.
        handler.set_flow_handler(flow_handler)

        await handler.run()
|
|
866
|
+
|
|
867
|
+
if (endpoint.websocket_path):
|
|
868
|
+
app.add_websocket_route(endpoint.websocket_path, websocket_endpoint)
|
|
869
|
+
|
|
485
870
|
if (endpoint.path):
|
|
871
|
+
|
|
486
872
|
if (endpoint.method == "GET"):
|
|
487
873
|
|
|
488
874
|
app.add_api_route(
|
|
@@ -554,9 +940,31 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
554
940
|
responses={500: response_500},
|
|
555
941
|
)
|
|
556
942
|
|
|
943
|
+
app.add_api_route(
|
|
944
|
+
path=f"{endpoint.path}/async",
|
|
945
|
+
endpoint=post_async_generation(request_type=AIQAsyncGenerateRequest,
|
|
946
|
+
final_result_type=GenerateSingleResponseType),
|
|
947
|
+
methods=[endpoint.method],
|
|
948
|
+
response_model=AIQAsyncGenerateResponse | AIQAsyncGenerationStatusResponse,
|
|
949
|
+
description="Start an async generate job",
|
|
950
|
+
responses={500: response_500},
|
|
951
|
+
)
|
|
557
952
|
else:
|
|
558
953
|
raise ValueError(f"Unsupported method {endpoint.method}")
|
|
559
954
|
|
|
955
|
+
app.add_api_route(
|
|
956
|
+
path=f"{endpoint.path}/async/job/{{job_id}}",
|
|
957
|
+
endpoint=get_async_job_status,
|
|
958
|
+
methods=["GET"],
|
|
959
|
+
response_model=AIQAsyncGenerationStatusResponse,
|
|
960
|
+
description="Get the status of an async job",
|
|
961
|
+
responses={
|
|
962
|
+
404: {
|
|
963
|
+
"description": "Job not found"
|
|
964
|
+
}, 500: response_500
|
|
965
|
+
},
|
|
966
|
+
)
|
|
967
|
+
|
|
560
968
|
if (endpoint.openai_api_path):
|
|
561
969
|
if (endpoint.method == "GET"):
|
|
562
970
|
|
|
@@ -582,26 +990,103 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
582
990
|
|
|
583
991
|
elif (endpoint.method == "POST"):
|
|
584
992
|
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
993
|
+
# Check if OpenAI v1 compatible endpoint is configured
|
|
994
|
+
openai_v1_path = getattr(endpoint, 'openai_api_v1_path', None)
|
|
995
|
+
|
|
996
|
+
# Always create legacy endpoints for backward compatibility (unless they conflict with v1 path)
|
|
997
|
+
if not openai_v1_path or openai_v1_path != endpoint.openai_api_path:
|
|
998
|
+
# <openai_api_path> = non-streaming (legacy behavior)
|
|
999
|
+
app.add_api_route(
|
|
1000
|
+
path=endpoint.openai_api_path,
|
|
1001
|
+
endpoint=post_single_endpoint(request_type=AIQChatRequest, result_type=AIQChatResponse),
|
|
1002
|
+
methods=[endpoint.method],
|
|
1003
|
+
response_model=AIQChatResponse,
|
|
1004
|
+
description=endpoint.description,
|
|
1005
|
+
responses={500: response_500},
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
# <openai_api_path>/stream = streaming (legacy behavior)
|
|
1009
|
+
app.add_api_route(
|
|
1010
|
+
path=f"{endpoint.openai_api_path}/stream",
|
|
1011
|
+
endpoint=post_streaming_endpoint(request_type=AIQChatRequest,
|
|
1012
|
+
streaming=True,
|
|
1013
|
+
result_type=AIQChatResponseChunk,
|
|
1014
|
+
output_type=AIQChatResponseChunk),
|
|
1015
|
+
methods=[endpoint.method],
|
|
1016
|
+
response_model=AIQChatResponseChunk | AIQResponseIntermediateStep,
|
|
1017
|
+
description=endpoint.description,
|
|
1018
|
+
responses={500: response_500},
|
|
1019
|
+
)
|
|
1020
|
+
|
|
1021
|
+
# Create OpenAI v1 compatible endpoint if configured
|
|
1022
|
+
if openai_v1_path:
|
|
1023
|
+
# OpenAI v1 Compatible Mode: Create single endpoint that handles both streaming and non-streaming
|
|
1024
|
+
app.add_api_route(
|
|
1025
|
+
path=openai_v1_path,
|
|
1026
|
+
endpoint=post_openai_api_compatible_endpoint(request_type=AIQChatRequest),
|
|
1027
|
+
methods=[endpoint.method],
|
|
1028
|
+
response_model=AIQChatResponse | AIQChatResponseChunk,
|
|
1029
|
+
description=f"{endpoint.description} (OpenAI Chat Completions API compatible)",
|
|
1030
|
+
responses={500: response_500},
|
|
1031
|
+
)
|
|
605
1032
|
|
|
606
1033
|
else:
|
|
607
1034
|
raise ValueError(f"Unsupported method {endpoint.method}")
|
|
1035
|
+
|
|
1036
|
+
async def add_authorization_route(self, app: FastAPI):
    """Register the OAuth2 authorization-code redirect route on ``app``.

    The route is only added when ``front_end_config.oauth2_callback_path`` is set. It completes the
    pending flow identified by the ``state`` query parameter: the authorization code is exchanged for
    a token, and that flow's future is resolved with the token (or with the exchange error).

    Args:
        app: The FastAPI application to add the route to.
    """

    from fastapi.responses import HTMLResponse

    from aiq.front_ends.fastapi.html_snippets.auth_code_grant_success import AUTH_REDIRECT_SUCCESS_HTML

    html_headers = {"Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache"}

    async def redirect_uri(request: Request):
        """
        Handle the redirect URI for OAuth2 authentication.

        Args:
            request: The FastAPI request object containing query parameters.

        Returns:
            HTMLResponse: The success page (200) when the token exchange succeeded, or a short error
            page (400) for an unknown or already-completed ``state`` or a failed token exchange.
        """
        state = request.query_params.get("state")

        # Hold the lock only for the lookup: the token exchange below is network I/O and must not
        # block other flows from being added or removed while it runs.
        async with self._outstanding_flows_lock:
            flow_state = self._outstanding_flows[state] if state and state in self._outstanding_flows else None

        # A flow whose future is already resolved was completed by an earlier (e.g. replayed)
        # callback; resolving the future again would raise InvalidStateError.
        if flow_state is None or flow_state.future.done():
            return HTMLResponse(content="Invalid state. Please restart the authentication process.",
                                status_code=400,
                                headers=html_headers)

        try:
            res = await flow_state.client.fetch_token(url=flow_state.config.token_url,
                                                      authorization_response=str(request.url),
                                                      code_verifier=flow_state.verifier,
                                                      state=state)
        except Exception as e:
            # Hand the failure to whoever awaits the flow, and do not show the success page.
            if not flow_state.future.done():
                flow_state.future.set_exception(e)
            return HTMLResponse(content="Authentication failed. Please restart the authentication process.",
                                status_code=400,
                                headers=html_headers)

        # Re-check: a concurrent callback for the same state may have resolved the future meanwhile.
        if not flow_state.future.done():
            flow_state.future.set_result(res)

        return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML, status_code=200, headers=html_headers)

    if (self.front_end_config.oauth2_callback_path):
        # Add the redirect URI route
        app.add_api_route(
            path=self.front_end_config.oauth2_callback_path,
            endpoint=redirect_uri,
            methods=["GET"],
            description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
|
|
1085
|
+
|
|
1086
|
+
async def _add_flow(self, state: str, flow_state: FlowState):
    """Register a pending OAuth2 flow under its ``state`` value.

    Writes are serialized with ``_outstanding_flows_lock``; an existing entry for the same
    ``state`` is overwritten.
    """
    lock = self._outstanding_flows_lock
    async with lock:
        self._outstanding_flows[state] = flow_state
|
|
1089
|
+
|
|
1090
|
+
async def _remove_flow(self, state: str):
|
|
1091
|
+
async with self._outstanding_flows_lock:
|
|
1092
|
+
del self._outstanding_flows[state]
|