aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiqtoolkit might be problematic; see the package's registry page for more details.

Files changed (220)
  1. aiq/agent/base.py +170 -8
  2. aiq/agent/dual_node.py +1 -1
  3. aiq/agent/react_agent/agent.py +146 -112
  4. aiq/agent/react_agent/prompt.py +1 -6
  5. aiq/agent/react_agent/register.py +36 -35
  6. aiq/agent/rewoo_agent/agent.py +36 -35
  7. aiq/agent/rewoo_agent/register.py +2 -2
  8. aiq/agent/tool_calling_agent/agent.py +3 -7
  9. aiq/agent/tool_calling_agent/register.py +1 -1
  10. aiq/authentication/__init__.py +14 -0
  11. aiq/authentication/api_key/__init__.py +14 -0
  12. aiq/authentication/api_key/api_key_auth_provider.py +92 -0
  13. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  14. aiq/authentication/api_key/register.py +26 -0
  15. aiq/authentication/exceptions/__init__.py +14 -0
  16. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  17. aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
  18. aiq/authentication/exceptions/call_back_exceptions.py +38 -0
  19. aiq/authentication/exceptions/request_exceptions.py +54 -0
  20. aiq/authentication/http_basic_auth/__init__.py +0 -0
  21. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  22. aiq/authentication/http_basic_auth/register.py +30 -0
  23. aiq/authentication/interfaces.py +93 -0
  24. aiq/authentication/oauth2/__init__.py +14 -0
  25. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  26. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  27. aiq/authentication/oauth2/register.py +25 -0
  28. aiq/authentication/register.py +21 -0
  29. aiq/builder/builder.py +64 -2
  30. aiq/builder/component_utils.py +16 -3
  31. aiq/builder/context.py +37 -0
  32. aiq/builder/eval_builder.py +43 -2
  33. aiq/builder/function.py +44 -12
  34. aiq/builder/function_base.py +1 -1
  35. aiq/builder/intermediate_step_manager.py +6 -8
  36. aiq/builder/user_interaction_manager.py +3 -0
  37. aiq/builder/workflow.py +23 -18
  38. aiq/builder/workflow_builder.py +421 -61
  39. aiq/cli/commands/info/list_mcp.py +103 -16
  40. aiq/cli/commands/sizing/__init__.py +14 -0
  41. aiq/cli/commands/sizing/calc.py +294 -0
  42. aiq/cli/commands/sizing/sizing.py +27 -0
  43. aiq/cli/commands/start.py +2 -1
  44. aiq/cli/entrypoint.py +2 -0
  45. aiq/cli/register_workflow.py +80 -0
  46. aiq/cli/type_registry.py +151 -30
  47. aiq/data_models/api_server.py +124 -12
  48. aiq/data_models/authentication.py +231 -0
  49. aiq/data_models/common.py +35 -7
  50. aiq/data_models/component.py +17 -9
  51. aiq/data_models/component_ref.py +33 -0
  52. aiq/data_models/config.py +60 -3
  53. aiq/data_models/dataset_handler.py +2 -1
  54. aiq/data_models/embedder.py +1 -0
  55. aiq/data_models/evaluate.py +23 -0
  56. aiq/data_models/function_dependencies.py +8 -0
  57. aiq/data_models/interactive.py +10 -1
  58. aiq/data_models/intermediate_step.py +38 -5
  59. aiq/data_models/its_strategy.py +30 -0
  60. aiq/data_models/llm.py +1 -0
  61. aiq/data_models/memory.py +1 -0
  62. aiq/data_models/object_store.py +44 -0
  63. aiq/data_models/profiler.py +1 -0
  64. aiq/data_models/retry_mixin.py +35 -0
  65. aiq/data_models/span.py +187 -0
  66. aiq/data_models/telemetry_exporter.py +2 -2
  67. aiq/embedder/nim_embedder.py +2 -1
  68. aiq/embedder/openai_embedder.py +2 -1
  69. aiq/eval/config.py +19 -1
  70. aiq/eval/dataset_handler/dataset_handler.py +87 -2
  71. aiq/eval/evaluate.py +208 -27
  72. aiq/eval/evaluator/base_evaluator.py +73 -0
  73. aiq/eval/evaluator/evaluator_model.py +1 -0
  74. aiq/eval/intermediate_step_adapter.py +11 -5
  75. aiq/eval/rag_evaluator/evaluate.py +55 -15
  76. aiq/eval/rag_evaluator/register.py +6 -1
  77. aiq/eval/remote_workflow.py +7 -2
  78. aiq/eval/runners/__init__.py +14 -0
  79. aiq/eval/runners/config.py +39 -0
  80. aiq/eval/runners/multi_eval_runner.py +54 -0
  81. aiq/eval/trajectory_evaluator/evaluate.py +22 -65
  82. aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
  83. aiq/eval/tunable_rag_evaluator/register.py +2 -0
  84. aiq/eval/usage_stats.py +41 -0
  85. aiq/eval/utils/output_uploader.py +10 -1
  86. aiq/eval/utils/weave_eval.py +184 -0
  87. aiq/experimental/__init__.py +0 -0
  88. aiq/experimental/decorators/__init__.py +0 -0
  89. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  90. aiq/experimental/inference_time_scaling/__init__.py +0 -0
  91. aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
  92. aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
  93. aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
  94. aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
  95. aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
  96. aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
  97. aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
  98. aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
  99. aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
  100. aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
  101. aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
  102. aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
  103. aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
  104. aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
  105. aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
  106. aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
  107. aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
  108. aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
  109. aiq/experimental/inference_time_scaling/register.py +36 -0
  110. aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
  111. aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
  112. aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
  113. aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
  114. aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
  115. aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
  116. aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
  117. aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
  118. aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
  119. aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
  120. aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
  121. aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
  122. aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
  123. aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
  124. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  125. aiq/front_ends/console/console_front_end_plugin.py +11 -2
  126. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  127. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  128. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  129. aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
  130. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  131. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
  132. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
  133. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  134. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  135. aiq/front_ends/fastapi/job_store.py +47 -25
  136. aiq/front_ends/fastapi/main.py +2 -0
  137. aiq/front_ends/fastapi/message_handler.py +108 -89
  138. aiq/front_ends/fastapi/step_adaptor.py +2 -1
  139. aiq/llm/aws_bedrock_llm.py +57 -0
  140. aiq/llm/nim_llm.py +2 -1
  141. aiq/llm/openai_llm.py +3 -2
  142. aiq/llm/register.py +1 -0
  143. aiq/meta/pypi.md +12 -12
  144. aiq/object_store/__init__.py +20 -0
  145. aiq/object_store/in_memory_object_store.py +74 -0
  146. aiq/object_store/interfaces.py +84 -0
  147. aiq/object_store/models.py +36 -0
  148. aiq/object_store/register.py +20 -0
  149. aiq/observability/__init__.py +14 -0
  150. aiq/observability/exporter/__init__.py +14 -0
  151. aiq/observability/exporter/base_exporter.py +449 -0
  152. aiq/observability/exporter/exporter.py +78 -0
  153. aiq/observability/exporter/file_exporter.py +33 -0
  154. aiq/observability/exporter/processing_exporter.py +269 -0
  155. aiq/observability/exporter/raw_exporter.py +52 -0
  156. aiq/observability/exporter/span_exporter.py +264 -0
  157. aiq/observability/exporter_manager.py +335 -0
  158. aiq/observability/mixin/__init__.py +14 -0
  159. aiq/observability/mixin/batch_config_mixin.py +26 -0
  160. aiq/observability/mixin/collector_config_mixin.py +23 -0
  161. aiq/observability/mixin/file_mixin.py +288 -0
  162. aiq/observability/mixin/file_mode.py +23 -0
  163. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  164. aiq/observability/mixin/serialize_mixin.py +61 -0
  165. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  166. aiq/observability/processor/__init__.py +14 -0
  167. aiq/observability/processor/batching_processor.py +316 -0
  168. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  169. aiq/observability/processor/processor.py +68 -0
  170. aiq/observability/register.py +36 -39
  171. aiq/observability/utils/__init__.py +14 -0
  172. aiq/observability/utils/dict_utils.py +236 -0
  173. aiq/observability/utils/time_utils.py +31 -0
  174. aiq/profiler/calc/__init__.py +14 -0
  175. aiq/profiler/calc/calc_runner.py +623 -0
  176. aiq/profiler/calc/calculations.py +288 -0
  177. aiq/profiler/calc/data_models.py +176 -0
  178. aiq/profiler/calc/plot.py +345 -0
  179. aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
  180. aiq/profiler/data_models.py +24 -0
  181. aiq/profiler/inference_metrics_model.py +3 -0
  182. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
  183. aiq/profiler/inference_optimization/data_models.py +2 -2
  184. aiq/profiler/inference_optimization/llm_metrics.py +2 -2
  185. aiq/profiler/profile_runner.py +61 -21
  186. aiq/runtime/loader.py +9 -3
  187. aiq/runtime/runner.py +23 -9
  188. aiq/runtime/session.py +25 -7
  189. aiq/runtime/user_metadata.py +2 -3
  190. aiq/tool/chat_completion.py +74 -0
  191. aiq/tool/code_execution/README.md +152 -0
  192. aiq/tool/code_execution/code_sandbox.py +151 -72
  193. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  194. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
  195. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
  196. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
  197. aiq/tool/code_execution/register.py +7 -3
  198. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  199. aiq/tool/mcp/exceptions.py +142 -0
  200. aiq/tool/mcp/mcp_client.py +41 -6
  201. aiq/tool/mcp/mcp_tool.py +3 -2
  202. aiq/tool/register.py +1 -0
  203. aiq/tool/server_tools.py +6 -3
  204. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  205. aiq/utils/exception_handlers/mcp.py +211 -0
  206. aiq/utils/io/model_processing.py +28 -0
  207. aiq/utils/log_utils.py +37 -0
  208. aiq/utils/string_utils.py +38 -0
  209. aiq/utils/type_converter.py +18 -2
  210. aiq/utils/type_utils.py +87 -0
  211. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/METADATA +53 -21
  212. aiqtoolkit-1.2.0rc1.dist-info/RECORD +436 -0
  213. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/WHEEL +1 -1
  214. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/entry_points.txt +3 -0
  215. aiq/front_ends/fastapi/websocket.py +0 -148
  216. aiq/observability/async_otel_listener.py +0 -429
  217. aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
  218. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  219. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  220. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,11 +16,13 @@
16
16
  import asyncio
17
17
  import logging
18
18
  import os
19
+ import time
19
20
  import typing
20
21
  from abc import ABC
21
22
  from abc import abstractmethod
23
+ from collections.abc import Awaitable
24
+ from collections.abc import Callable
22
25
  from contextlib import asynccontextmanager
23
- from functools import partial
24
26
  from pathlib import Path
25
27
 
26
28
  from fastapi import BackgroundTasks
@@ -28,10 +30,13 @@ from fastapi import Body
28
30
  from fastapi import FastAPI
29
31
  from fastapi import Request
30
32
  from fastapi import Response
33
+ from fastapi import UploadFile
31
34
  from fastapi.exceptions import HTTPException
32
35
  from fastapi.middleware.cors import CORSMiddleware
33
36
  from fastapi.responses import StreamingResponse
34
37
  from pydantic import BaseModel
38
+ from pydantic import Field
39
+ from starlette.websockets import WebSocket
35
40
 
36
41
  from aiq.builder.workflow_builder import WorkflowBuilder
37
42
  from aiq.data_models.api_server import AIQChatRequest
@@ -39,20 +44,28 @@ from aiq.data_models.api_server import AIQChatResponse
39
44
  from aiq.data_models.api_server import AIQChatResponseChunk
40
45
  from aiq.data_models.api_server import AIQResponseIntermediateStep
41
46
  from aiq.data_models.config import AIQConfig
47
+ from aiq.data_models.object_store import KeyAlreadyExistsError
48
+ from aiq.data_models.object_store import NoSuchKeyError
42
49
  from aiq.eval.config import EvaluationRunOutput
43
50
  from aiq.eval.evaluate import EvaluationRun
44
51
  from aiq.eval.evaluate import EvaluationRunConfig
52
+ from aiq.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
53
+ from aiq.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
54
+ from aiq.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
55
+ from aiq.front_ends.fastapi.fastapi_front_end_config import AIQAsyncGenerateResponse
56
+ from aiq.front_ends.fastapi.fastapi_front_end_config import AIQAsyncGenerationStatusResponse
45
57
  from aiq.front_ends.fastapi.fastapi_front_end_config import AIQEvaluateRequest
46
58
  from aiq.front_ends.fastapi.fastapi_front_end_config import AIQEvaluateResponse
47
59
  from aiq.front_ends.fastapi.fastapi_front_end_config import AIQEvaluateStatusResponse
48
60
  from aiq.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
49
61
  from aiq.front_ends.fastapi.job_store import JobInfo
50
62
  from aiq.front_ends.fastapi.job_store import JobStore
63
+ from aiq.front_ends.fastapi.message_handler import WebSocketMessageHandler
51
64
  from aiq.front_ends.fastapi.response_helpers import generate_single_response
52
65
  from aiq.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
53
66
  from aiq.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
54
67
  from aiq.front_ends.fastapi.step_adaptor import StepAdaptor
55
- from aiq.front_ends.fastapi.websocket import AIQWebSocket
68
+ from aiq.object_store.models import ObjectStoreItem
56
69
  from aiq.runtime.session import AIQSessionManager
57
70
 
58
71
  logger = logging.getLogger(__name__)
@@ -68,13 +81,16 @@ class FastApiFrontEndPluginWorkerBase(ABC):
68
81
 
69
82
  self._front_end_config = config.general.front_end
70
83
 
84
+ self._cleanup_tasks: list[str] = []
85
+ self._cleanup_tasks_lock = asyncio.Lock()
86
+ self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
87
+
71
88
  @property
72
89
  def config(self) -> AIQConfig:
73
90
  return self._config
74
91
 
75
92
  @property
76
93
  def front_end_config(self) -> FastApiFrontEndConfig:
77
-
78
94
  return self._front_end_config
79
95
 
80
96
  def build_app(self) -> FastAPI:
@@ -92,17 +108,30 @@ class FastApiFrontEndPluginWorkerBase(ABC):
92
108
  yield
93
109
 
94
110
  # If a cleanup task is running, cancel it
95
- cleanup_task = getattr(starting_app.state, "cleanup_task", None)
96
- if cleanup_task:
97
- logger.info("Cancelling cleanup task")
98
- cleanup_task.cancel()
111
+ async with self._cleanup_tasks_lock:
112
+
113
+ # Cancel all cleanup tasks
114
+ for task_name in self._cleanup_tasks:
115
+ cleanup_task: asyncio.Task | None = getattr(starting_app.state, task_name, None)
116
+ if cleanup_task is not None:
117
+ logger.info("Cancelling %s cleanup task", task_name)
118
+ cleanup_task.cancel()
119
+ else:
120
+ logger.warning("No cleanup task found for %s", task_name)
121
+
122
+ self._cleanup_tasks.clear()
99
123
 
100
124
  logger.debug("Closing AIQ Toolkit server from process %s", os.getpid())
101
125
 
102
126
  aiq_app = FastAPI(lifespan=lifespan)
103
127
 
128
+ # Configure app CORS.
104
129
  self.set_cors_config(aiq_app)
105
130
 
131
+ @aiq_app.middleware("http")
132
+ async def authentication_log_filter(request: Request, call_next: Callable[[Request], Awaitable[Response]]):
133
+ return await self._suppress_authentication_logs(request, call_next)
134
+
106
135
  return aiq_app
107
136
 
108
137
  def set_cors_config(self, aiq_app: FastAPI) -> None:
@@ -137,6 +166,26 @@ class FastApiFrontEndPluginWorkerBase(ABC):
137
166
  **cors_kwargs,
138
167
  )
139
168
 
169
+ async def _suppress_authentication_logs(self, request: Request,
170
+ call_next: Callable[[Request], Awaitable[Response]]) -> Response:
171
+ """
172
+ Intercepts authentication request and supreses logs that contain sensitive data.
173
+ """
174
+ from aiq.utils.log_utils import LogFilter
175
+
176
+ logs_to_suppress: list[str] = []
177
+
178
+ if (self.front_end_config.oauth2_callback_path):
179
+ logs_to_suppress.append(self.front_end_config.oauth2_callback_path)
180
+
181
+ logging.getLogger("uvicorn.access").addFilter(LogFilter(logs_to_suppress))
182
+ try:
183
+ response = await call_next(request)
184
+ finally:
185
+ logging.getLogger("uvicorn.access").removeFilter(LogFilter(logs_to_suppress))
186
+
187
+ return response
188
+
140
189
  @abstractmethod
141
190
  async def configure(self, app: FastAPI, builder: WorkflowBuilder):
142
191
  pass
@@ -153,6 +202,38 @@ class RouteInfo(BaseModel):
153
202
 
154
203
  class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
155
204
 
205
+ def __init__(self, config: AIQConfig):
206
+ super().__init__(config)
207
+
208
+ self._outstanding_flows: dict[str, FlowState] = {}
209
+ self._outstanding_flows_lock = asyncio.Lock()
210
+
211
+ @staticmethod
212
+ async def _periodic_cleanup(name: str, job_store: JobStore, sleep_time_sec: int = 300):
213
+ while True:
214
+ try:
215
+ job_store.cleanup_expired_jobs()
216
+ logger.debug("Expired %s jobs cleaned up", name)
217
+ except Exception as e:
218
+ logger.error("Error during %s job cleanup: %s", name, e)
219
+ await asyncio.sleep(sleep_time_sec)
220
+
221
+ async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
222
+ # Schedule periodic cleanup of expired jobs on first job creation
223
+ attr_name = f"{name}_cleanup_task"
224
+
225
+ # Cheap check, if it doesn't exist, we will need to re-check after we acquire the lock
226
+ if not hasattr(app.state, attr_name):
227
+ async with self._cleanup_tasks_lock:
228
+ if not hasattr(app.state, attr_name):
229
+ logger.info("Starting %s periodic cleanup task", name)
230
+ setattr(
231
+ app.state,
232
+ attr_name,
233
+ asyncio.create_task(
234
+ self._periodic_cleanup(name=name, job_store=job_store, sleep_time_sec=sleep_time_sec)))
235
+ self._cleanup_tasks.append(attr_name)
236
+
156
237
  def get_step_adaptor(self) -> StepAdaptor:
157
238
 
158
239
  return StepAdaptor(self.front_end_config.step_adaptor)
@@ -168,6 +249,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
168
249
 
169
250
  await self.add_default_route(app, AIQSessionManager(builder.build()))
170
251
  await self.add_evaluate_route(app, AIQSessionManager(builder.build()))
252
+ await self.add_static_files_route(app, builder)
253
+ await self.add_authorization_route(app)
171
254
 
172
255
  for ep in self.front_end_config.endpoints:
173
256
 
@@ -198,21 +281,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
198
281
  # Don't run multiple evaluations at the same time
199
282
  evaluation_lock = asyncio.Lock()
200
283
 
201
- async def periodic_cleanup(job_store: JobStore):
202
- while True:
203
- try:
204
- job_store.cleanup_expired_jobs()
205
- logger.debug("Expired jobs cleaned up")
206
- except Exception as e:
207
- logger.error("Error during job cleanup: %s", str(e))
208
- await asyncio.sleep(300) # every 5 minutes
209
-
210
- def create_cleanup_task():
211
- # Schedule periodic cleanup of expired jobs on first job creation
212
- if not hasattr(app.state, "cleanup_task"):
213
- logger.info("Starting periodic cleanup task")
214
- app.state.cleanup_task = asyncio.create_task(periodic_cleanup(job_store))
215
-
216
284
  async def run_evaluation(job_id: str, config_file: str, reps: int, session_manager: AIQSessionManager):
217
285
  """Background task to run the evaluation."""
218
286
  async with evaluation_lock:
@@ -250,7 +318,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
250
318
  return AIQEvaluateResponse(job_id=job.job_id, status=job.status)
251
319
 
252
320
  job_id = job_store.create_job(request.config_file, request.job_id, request.expiry_seconds)
253
- create_cleanup_task()
321
+ await self.create_cleanup_task(app=app, name="async_evaluation", job_store=job_store)
254
322
  background_tasks.add_task(run_evaluation, job_id, request.config_file, request.reps, session_manager)
255
323
 
256
324
  return AIQEvaluateResponse(job_id=job_id, status="submitted")
@@ -276,7 +344,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
276
344
  if not job:
277
345
  logger.warning("Job %s not found", job_id)
278
346
  raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
279
- logger.info(f"Found job {job_id} with status {job.status}")
347
+ logger.info("Found job %s with status %s", job_id, job.status)
280
348
  return translate_job_to_response(job)
281
349
 
282
350
  async def get_last_job_status(http_request: Request) -> AIQEvaluateStatusResponse:
@@ -355,6 +423,100 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
355
423
  responses={500: response_500},
356
424
  )
357
425
 
426
+ async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
427
+
428
+ if not self.front_end_config.object_store:
429
+ logger.debug("No object store configured, skipping static files route")
430
+ return
431
+
432
+ object_store_client = await builder.get_object_store_client(self.front_end_config.object_store)
433
+
434
+ def sanitize_path(path: str) -> str:
435
+ sanitized_path = os.path.normpath(path.strip("/"))
436
+ if sanitized_path == ".":
437
+ raise HTTPException(status_code=400, detail="Invalid file path.")
438
+ filename = os.path.basename(sanitized_path)
439
+ if not filename:
440
+ raise HTTPException(status_code=400, detail="Filename cannot be empty.")
441
+ return sanitized_path
442
+
443
+ # Upload static files to the object store; if key is present, it will fail with 409 Conflict
444
+ async def add_static_file(file_path: str, file: UploadFile):
445
+ sanitized_file_path = sanitize_path(file_path)
446
+ file_data = await file.read()
447
+
448
+ try:
449
+ await object_store_client.put_object(sanitized_file_path,
450
+ ObjectStoreItem(data=file_data, content_type=file.content_type))
451
+ except KeyAlreadyExistsError as e:
452
+ raise HTTPException(status_code=409, detail=str(e)) from e
453
+
454
+ return {"filename": sanitized_file_path}
455
+
456
+ # Upsert static files to the object store; if key is present, it will overwrite the file
457
+ async def upsert_static_file(file_path: str, file: UploadFile):
458
+ sanitized_file_path = sanitize_path(file_path)
459
+ file_data = await file.read()
460
+
461
+ await object_store_client.upsert_object(sanitized_file_path,
462
+ ObjectStoreItem(data=file_data, content_type=file.content_type))
463
+
464
+ return {"filename": sanitized_file_path}
465
+
466
+ # Get static files from the object store
467
+ async def get_static_file(file_path: str):
468
+
469
+ try:
470
+ file_data = await object_store_client.get_object(file_path)
471
+ except NoSuchKeyError as e:
472
+ raise HTTPException(status_code=404, detail=str(e)) from e
473
+
474
+ filename = file_path.split("/")[-1]
475
+
476
+ async def reader():
477
+ yield file_data.data
478
+
479
+ return StreamingResponse(reader(),
480
+ media_type=file_data.content_type,
481
+ headers={"Content-Disposition": f"attachment; filename={filename}"})
482
+
483
+ async def delete_static_file(file_path: str):
484
+ try:
485
+ await object_store_client.delete_object(file_path)
486
+ except NoSuchKeyError as e:
487
+ raise HTTPException(status_code=404, detail=str(e)) from e
488
+
489
+ return Response(status_code=204)
490
+
491
+ # Add the static files route to the FastAPI app
492
+ app.add_api_route(
493
+ path="/static/{file_path:path}",
494
+ endpoint=add_static_file,
495
+ methods=["POST"],
496
+ description="Upload a static file to the object store",
497
+ )
498
+
499
+ app.add_api_route(
500
+ path="/static/{file_path:path}",
501
+ endpoint=upsert_static_file,
502
+ methods=["PUT"],
503
+ description="Upsert a static file to the object store",
504
+ )
505
+
506
+ app.add_api_route(
507
+ path="/static/{file_path:path}",
508
+ endpoint=get_static_file,
509
+ methods=["GET"],
510
+ description="Get a static file from the object store",
511
+ )
512
+
513
+ app.add_api_route(
514
+ path="/static/{file_path:path}",
515
+ endpoint=delete_static_file,
516
+ methods=["DELETE"],
517
+ description="Delete a static file from the object store",
518
+ )
519
+
358
520
  async def add_route(self,
359
521
  app: FastAPI,
360
522
  endpoint: FastApiFrontEndConfig.EndpointBase,
@@ -362,17 +524,32 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
362
524
 
363
525
  workflow = session_manager.workflow
364
526
 
365
- if (endpoint.websocket_path):
366
- app.add_websocket_route(endpoint.websocket_path,
367
- partial(AIQWebSocket, session_manager, self.get_step_adaptor()))
368
-
369
527
  GenerateBodyType = workflow.input_schema # pylint: disable=invalid-name
370
528
  GenerateStreamResponseType = workflow.streaming_output_schema # pylint: disable=invalid-name
371
529
  GenerateSingleResponseType = workflow.single_output_schema # pylint: disable=invalid-name
372
530
 
531
+ # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
532
+ # Consider prefixing these with "aiq_" to avoid conflicts
533
+ class AIQAsyncGenerateRequest(GenerateBodyType):
534
+ job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
535
+ sync_timeout: int = Field(
536
+ default=0,
537
+ ge=0,
538
+ le=300,
539
+ description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
540
+ "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
541
+ expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
542
+ ge=JobStore.MIN_EXPIRY,
543
+ le=JobStore.MAX_EXPIRY,
544
+ description="Optional time (in seconds) before the job expires. "
545
+ "Clamped between 600 (10 min) and 86400 (24h).")
546
+
373
547
  # Ensure that the input is in the body. POD types are treated as query parameters
374
548
  if (not issubclass(GenerateBodyType, BaseModel)):
375
549
  GenerateBodyType = typing.Annotated[GenerateBodyType, Body()]
550
+ else:
551
+ logger.info("Expecting generate request payloads in the following format: %s",
552
+ GenerateBodyType.model_fields)
376
553
 
377
554
  response_500 = {
378
555
  "description": "Internal Server Error",
@@ -385,13 +562,20 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
385
562
  },
386
563
  }
387
564
 
565
+ # Create job store for tracking async generation jobs
566
+ job_store = JobStore()
567
+
568
+ # Run up to max_running_async_jobs jobs at the same time
569
+ async_job_concurrency = asyncio.Semaphore(self._front_end_config.max_running_async_jobs)
570
+
388
571
  def get_single_endpoint(result_type: type | None):
389
572
 
390
573
  async def get_single(response: Response, request: Request):
391
574
 
392
575
  response.headers["Content-Type"] = "application/json"
393
576
 
394
- async with session_manager.session(request=request):
577
+ async with session_manager.session(request=request,
578
+ user_authentication_callback=self._http_flow_handler.authenticate):
395
579
 
396
580
  return await generate_single_response(None, session_manager, result_type=result_type)
397
581
 
@@ -401,7 +585,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
401
585
 
402
586
  async def get_stream(request: Request):
403
587
 
404
- async with session_manager.session(request=request):
588
+ async with session_manager.session(request=request,
589
+ user_authentication_callback=self._http_flow_handler.authenticate):
405
590
 
406
591
  return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
407
592
  content=generate_streaming_response_as_str(
@@ -435,7 +620,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
435
620
 
436
621
  response.headers["Content-Type"] = "application/json"
437
622
 
438
- async with session_manager.session(request=request):
623
+ async with session_manager.session(request=request,
624
+ user_authentication_callback=self._http_flow_handler.authenticate):
439
625
 
440
626
  return await generate_single_response(payload, session_manager, result_type=result_type)
441
627
 
@@ -448,7 +634,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
448
634
 
449
635
  async def post_stream(request: Request, payload: request_type):
450
636
 
451
- async with session_manager.session(request=request):
637
+ async with session_manager.session(request=request,
638
+ user_authentication_callback=self._http_flow_handler.authenticate):
452
639
 
453
640
  return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
454
641
  content=generate_streaming_response_as_str(
@@ -482,7 +669,206 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
482
669
 
483
670
  return post_stream
484
671
 
672
def post_openai_api_compatible_endpoint(request_type: type):
    """
    Build an OpenAI Chat Completions compatible route handler.

    A single route serves both modes, selected by the payload's ``stream`` attribute (``False`` when the
    attribute is absent):

    * streaming: a ``text/event-stream`` response of ``AIQChatResponseChunk`` frames;
    * non-streaming: one ``AIQChatResponse``. If the workflow only supports streaming, the workflow is run in
      streaming mode and the chunk contents are concatenated into a single response.

    Args:
        request_type: The payload model accepted by the route (used by FastAPI for request validation).

    Returns:
        The async route handler.
    """

    async def _collect_streamed_response(payload) -> AIQChatResponse:
        """Run the workflow in streaming mode and join every chunk's delta content into one response."""
        pieces: list[str] = []
        async for chunk_str in generate_streaming_response_as_str(payload,
                                                                  session_manager=session_manager,
                                                                  streaming=True,
                                                                  step_adaptor=self.get_step_adaptor(),
                                                                  result_type=AIQChatResponseChunk,
                                                                  output_type=AIQChatResponseChunk):
            # Only "data: " frames carry chat chunks; "data: [DONE]" is the end-of-stream sentinel.
            if not chunk_str.startswith("data: ") or chunk_str.startswith("data: [DONE]"):
                continue
            chunk_data = chunk_str[len("data: "):].strip()
            if not chunk_data:
                continue
            try:
                chunk_json = AIQChatResponseChunk.model_validate_json(chunk_data)
            except ValueError:
                # pydantic's ValidationError subclasses ValueError: skip frames that are not chat chunks,
                # but leave a trace instead of silently discarding them.
                logger.debug("Skipping non-chat stream frame while aggregating response: %s", chunk_data)
                continue
            if chunk_json.choices and chunk_json.choices[0].delta and chunk_json.choices[0].delta.content is not None:
                pieces.append(chunk_json.choices[0].delta.content)

        return AIQChatResponse.from_string("".join(pieces))

    async def post_openai_api_compatible(response: Response, request: Request, payload: request_type):
        # Check if streaming is requested
        stream_requested = getattr(payload, 'stream', False)

        # Pass the HTTP auth callback like every other HTTP endpoint, so workflows that need user
        # authentication behave the same through this route.
        async with session_manager.session(request=request,
                                           user_authentication_callback=self._http_flow_handler.authenticate):
            if stream_requested:
                return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
                                         content=generate_streaming_response_as_str(
                                             payload,
                                             session_manager=session_manager,
                                             streaming=True,
                                             step_adaptor=self.get_step_adaptor(),
                                             result_type=AIQChatResponseChunk,
                                             output_type=AIQChatResponseChunk))

            response.headers["Content-Type"] = "application/json"
            try:
                return await generate_single_response(payload, session_manager, result_type=AIQChatResponse)
            except ValueError as e:
                if "Cannot get a single output value for streaming workflows" not in str(e):
                    raise

            # The workflow only supports streaming but the client asked for a single response:
            # stream internally and collapse the chunks into one AIQChatResponse.
            return await _collect_streamed_response(payload)

    return post_openai_api_compatible
731
+
732
async def run_generation(job_id: str,
                         payload: typing.Any,
                         session_manager: AIQSessionManager,
                         result_type: type):
    """Background task that runs one async generation job and records its outcome in ``job_store``.

    The work runs inside ``async with async_job_concurrency`` (presumably a semaphore bounding concurrent
    jobs; confirm at its definition). On success the job is marked ``"success"`` with the generated result.
    On any exception it is marked ``"failure"`` with the error text.

    Args:
        job_id: Identifier of the job created in ``job_store``.
        payload: The request payload forwarded to ``generate_single_response``.
        session_manager: Session manager used to run the workflow.
        result_type: Expected type of the single generated result.
    """
    async with async_job_concurrency:
        try:
            result = await generate_single_response(payload=payload,
                                                    session_manager=session_manager,
                                                    result_type=result_type)
            job_store.update_status(job_id, "success", output=result)
        except Exception as e:
            # logger.exception keeps the traceback, which is otherwise lost for background jobs.
            logger.exception("Error in async generation job %s: %s", job_id, e)
            job_store.update_status(job_id, "failure", error=str(e))
746
+
747
def _job_status_to_response(job: JobInfo) -> AIQAsyncGenerationStatusResponse:
    """Translate a job-store record into the public ``AIQAsyncGenerationStatusResponse`` model.

    A non-``None`` job output is converted with ``model_dump()``; the expiry timestamp is obtained from
    ``job_store.get_expires_at``.
    """
    raw_output = job.output
    dumped_output = raw_output.model_dump() if raw_output is not None else None

    return AIQAsyncGenerationStatusResponse(
        job_id=job.job_id,
        status=job.status,
        error=job.error,
        output=dumped_output,
        created_at=job.created_at,
        updated_at=job.updated_at,
        expires_at=job_store.get_expires_at(job),
    )
758
+
759
def post_async_generation(request_type: type, final_result_type: type):
    """Build the handler that starts an async generation job.

    Args:
        request_type: Payload model accepted by the route.
        final_result_type: Result type passed to ``run_generation`` for the job's output.

    Returns:
        The async route handler.
    """

    async def start_async_generation(
            request: request_type, background_tasks: BackgroundTasks, response: Response,
            http_request: Request) -> AIQAsyncGenerateResponse | AIQAsyncGenerationStatusResponse:
        """Handle async generation requests.

        If ``request.job_id`` names an existing job, that job's id and status are returned and no new work
        is started. Otherwise a job is created and started immediately. The handler then waits up to
        ``request.sync_timeout`` seconds: a job that finished in time is returned with HTTP 200, otherwise
        HTTP 202 with status ``"submitted"`` is returned while the job keeps running.
        """

        async with session_manager.session(request=http_request):

            # if job_id is present and already exists return the job info
            if request.job_id:
                job = job_store.get_job(request.job_id)
                if job:
                    return AIQAsyncGenerateResponse(job_id=job.job_id, status=job.status)

            job_id = job_store.create_job(job_id=request.job_id, expiry_seconds=request.expiry_seconds)
            await self.create_cleanup_task(app=app, name="async_generation", job_store=job_store)

            # FastAPI/Starlette background tasks don't start until after the response is sent, so start the
            # generation now with asyncio.create_task and give the background task a function that awaits
            # it. This allows the work to begin immediately while the background task awaits its result.
            task = asyncio.create_task(
                run_generation(job_id=job_id,
                               payload=request,
                               session_manager=session_manager,
                               result_type=final_result_type))

            async def wrapped_task(t: asyncio.Task):
                return await t

            background_tasks.add_task(wrapped_task, task)

            # Wait for completion within the synchronous window without polling or relying on wall-clock
            # time. asyncio.wait does not cancel the task on timeout, so an unfinished job keeps running.
            await asyncio.wait({task}, timeout=request.sync_timeout)

            job = job_store.get_job(job_id)
            if job is not None and job.status not in job_store.ACTIVE_STATUS:
                # If the job is done, return the result
                response.status_code = 200
                return _job_status_to_response(job)

            response.status_code = 202
            return AIQAsyncGenerateResponse(job_id=job_id, status="submitted")

    return start_async_generation
807
+
808
async def get_async_job_status(job_id: str, http_request: Request) -> AIQAsyncGenerationStatusResponse:
    """Look up an async job by id and return its current status.

    Raises:
        HTTPException: 404 when the job store has no job with this id.
    """
    logger.info("Getting status for job %s", job_id)

    async with session_manager.session(request=http_request):

        found_job = job_store.get_job(job_id)
        if found_job:
            logger.info("Found job %s with status %s", job_id, found_job.status)
            return _job_status_to_response(found_job)

        logger.warning("Job %s not found", job_id)
        raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
821
+
822
def _merge_session_cookie(headers: list[tuple[bytes, bytes]],
                          session_id: str) -> tuple[list[tuple[bytes, bytes]], str]:
    """Return ``(new_headers, action)`` with an ``aiqtoolkit-session`` cookie merged into ASGI ``headers``.

    ``headers`` is copied, never mutated. Only the first ``cookie`` header is considered. ``action`` is:

    * ``"present"``  - that header already contains a cookie named exactly ``aiqtoolkit-session``;
      headers are returned unchanged (same-origin case).
    * ``"appended"`` - ``aiqtoolkit-session=<session_id>`` was appended to that header (cross-origin case).
    * ``"added"``    - there was no ``cookie`` header, so a new one was appended.
    """
    cookie_header = f"aiqtoolkit-session={session_id}"
    merged = list(headers)

    for index, (name, value) in enumerate(merged):
        if name != b"cookie":
            continue
        cookie_str = value.decode()
        # Compare exact cookie names; a substring test would also match e.g. "my-aiqtoolkit-session=".
        cookie_names = {pair.split("=", 1)[0].strip() for pair in cookie_str.split(";")}
        if "aiqtoolkit-session" in cookie_names:
            return merged, "present"
        merged[index] = (name, f"{cookie_str}; {cookie_header}".encode())
        return merged, "appended"

    merged.append((b"cookie", cookie_header.encode()))
    return merged, "added"


async def websocket_endpoint(websocket: WebSocket):
    """Serve a workflow WebSocket connection.

    If the connection URL carries a ``session`` query parameter, it is merged into the scope's ``cookie``
    header as the ``aiqtoolkit-session`` cookie before the message handler is created. This handling is
    intended to work for both cross-origin and same-origin connections.
    """

    session_id = websocket.query_params.get("session")
    if session_id:
        headers, action = _merge_session_cookie(websocket.scope.get("headers", []), session_id)

        if action == "present":
            logger.info("WebSocket: Session cookie already present in headers (same-origin)")
        elif action == "appended":
            logger.info("WebSocket: Added session cookie to existing cookie header: %s", session_id[:10] + "...")
        else:
            logger.info("WebSocket: Added new session cookie header: %s", session_id[:10] + "...")

        # Update the websocket scope with the modified headers
        websocket.scope["headers"] = headers

    async with WebSocketMessageHandler(websocket, session_manager, self.get_step_adaptor()) as handler:

        flow_handler = WebSocketAuthenticationFlowHandler(self._add_flow, self._remove_flow, handler)

        # The message handler and the flow handler need each other, so the flow handler is attached
        # only after both objects exist.
        handler.set_flow_handler(flow_handler)

        await handler.run()
866
+
867
+ if (endpoint.websocket_path):
868
+ app.add_websocket_route(endpoint.websocket_path, websocket_endpoint)
869
+
485
870
  if (endpoint.path):
871
+
486
872
  if (endpoint.method == "GET"):
487
873
 
488
874
  app.add_api_route(
@@ -554,9 +940,31 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
554
940
  responses={500: response_500},
555
941
  )
556
942
 
943
+ app.add_api_route(
944
+ path=f"{endpoint.path}/async",
945
+ endpoint=post_async_generation(request_type=AIQAsyncGenerateRequest,
946
+ final_result_type=GenerateSingleResponseType),
947
+ methods=[endpoint.method],
948
+ response_model=AIQAsyncGenerateResponse | AIQAsyncGenerationStatusResponse,
949
+ description="Start an async generate job",
950
+ responses={500: response_500},
951
+ )
557
952
  else:
558
953
  raise ValueError(f"Unsupported method {endpoint.method}")
559
954
 
955
+ app.add_api_route(
956
+ path=f"{endpoint.path}/async/job/{{job_id}}",
957
+ endpoint=get_async_job_status,
958
+ methods=["GET"],
959
+ response_model=AIQAsyncGenerationStatusResponse,
960
+ description="Get the status of an async job",
961
+ responses={
962
+ 404: {
963
+ "description": "Job not found"
964
+ }, 500: response_500
965
+ },
966
+ )
967
+
560
968
  if (endpoint.openai_api_path):
561
969
  if (endpoint.method == "GET"):
562
970
 
@@ -582,26 +990,103 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
582
990
 
583
991
  elif (endpoint.method == "POST"):
584
992
 
585
- app.add_api_route(
586
- path=endpoint.openai_api_path,
587
- endpoint=post_single_endpoint(request_type=AIQChatRequest, result_type=AIQChatResponse),
588
- methods=[endpoint.method],
589
- response_model=AIQChatResponse,
590
- description=endpoint.description,
591
- responses={500: response_500},
592
- )
593
-
594
- app.add_api_route(
595
- path=f"{endpoint.openai_api_path}/stream",
596
- endpoint=post_streaming_endpoint(request_type=AIQChatRequest,
597
- streaming=True,
598
- result_type=AIQChatResponseChunk,
599
- output_type=AIQChatResponseChunk),
600
- methods=[endpoint.method],
601
- response_model=AIQChatResponseChunk | AIQResponseIntermediateStep,
602
- description=endpoint.description,
603
- responses={500: response_500},
604
- )
993
+ # Check if OpenAI v1 compatible endpoint is configured
994
+ openai_v1_path = getattr(endpoint, 'openai_api_v1_path', None)
995
+
996
+ # Always create legacy endpoints for backward compatibility (unless they conflict with v1 path)
997
+ if not openai_v1_path or openai_v1_path != endpoint.openai_api_path:
998
+ # <openai_api_path> = non-streaming (legacy behavior)
999
+ app.add_api_route(
1000
+ path=endpoint.openai_api_path,
1001
+ endpoint=post_single_endpoint(request_type=AIQChatRequest, result_type=AIQChatResponse),
1002
+ methods=[endpoint.method],
1003
+ response_model=AIQChatResponse,
1004
+ description=endpoint.description,
1005
+ responses={500: response_500},
1006
+ )
1007
+
1008
+ # <openai_api_path>/stream = streaming (legacy behavior)
1009
+ app.add_api_route(
1010
+ path=f"{endpoint.openai_api_path}/stream",
1011
+ endpoint=post_streaming_endpoint(request_type=AIQChatRequest,
1012
+ streaming=True,
1013
+ result_type=AIQChatResponseChunk,
1014
+ output_type=AIQChatResponseChunk),
1015
+ methods=[endpoint.method],
1016
+ response_model=AIQChatResponseChunk | AIQResponseIntermediateStep,
1017
+ description=endpoint.description,
1018
+ responses={500: response_500},
1019
+ )
1020
+
1021
+ # Create OpenAI v1 compatible endpoint if configured
1022
+ if openai_v1_path:
1023
+ # OpenAI v1 Compatible Mode: Create single endpoint that handles both streaming and non-streaming
1024
+ app.add_api_route(
1025
+ path=openai_v1_path,
1026
+ endpoint=post_openai_api_compatible_endpoint(request_type=AIQChatRequest),
1027
+ methods=[endpoint.method],
1028
+ response_model=AIQChatResponse | AIQChatResponseChunk,
1029
+ description=f"{endpoint.description} (OpenAI Chat Completions API compatible)",
1030
+ responses={500: response_500},
1031
+ )
605
1032
 
606
1033
  else:
607
1034
  raise ValueError(f"Unsupported method {endpoint.method}")
1035
+
1036
async def add_authorization_route(self, app: FastAPI):
    """Register the OAuth2 authorization-code redirect (callback) route on ``app``.

    The route is added only when ``self.front_end_config.oauth2_callback_path`` is set.
    """

    from fastapi.responses import HTMLResponse

    from aiq.front_ends.fastapi.html_snippets.auth_code_grant_success import AUTH_REDIRECT_SUCCESS_HTML

    async def redirect_uri(request: Request):
        """
        Handle the redirect URI for OAuth2 authentication.

        Finds the outstanding flow registered under the ``state`` query parameter. It then exchanges the
        authorization code for a token and resolves that flow's future with the token, or with the exception
        raised by the exchange.

        Args:
            request: The FastAPI request object containing query parameters.

        Returns:
            HTMLResponse: the success page (HTTP 200) once the flow has been handled, or an HTTP 400 page when
            ``state`` is missing or matches no outstanding flow.
        """
        state = request.query_params.get("state")

        # Hold the lock only for the lookup. The token exchange below is network I/O and must not block
        # other flows from being registered or removed while it runs.
        async with self._outstanding_flows_lock:
            flow_state = self._outstanding_flows.get(state) if state else None

        if flow_state is None:
            return HTMLResponse(content="Invalid state. Please restart the authentication process.",
                                status_code=400,
                                headers={"Cache-Control": "no-cache"})

        try:
            res = await flow_state.client.fetch_token(url=flow_state.config.token_url,
                                                      authorization_response=str(request.url),
                                                      code_verifier=flow_state.verifier,
                                                      state=state)
        except Exception as e:
            # Deliver the failure to whoever awaits the flow. Skip an already-resolved or cancelled future
            # (e.g. a duplicate callback), which would otherwise raise InvalidStateError.
            if not flow_state.future.done():
                flow_state.future.set_exception(e)
        else:
            if not flow_state.future.done():
                flow_state.future.set_result(res)

        # NOTE(review): the success page is shown even when the token exchange failed; the error is only
        # delivered through the flow's future.
        return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
                            status_code=200,
                            headers={
                                "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache"
                            })

    if (self.front_end_config.oauth2_callback_path):
        # Add the redirect URI route
        app.add_api_route(
            path=self.front_end_config.oauth2_callback_path,
            endpoint=redirect_uri,
            methods=["GET"],
            description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
1085
+
1086
async def _add_flow(self, state: str, flow_state: FlowState):
    """Record an in-progress OAuth2 flow under its ``state`` value.

    The OAuth2 redirect handler later looks flows up by this ``state``; all access to the registry goes
    through ``self._outstanding_flows_lock``.
    """
    lock = self._outstanding_flows_lock
    async with lock:
        self._outstanding_flows[state] = flow_state
1089
+
1090
+ async def _remove_flow(self, state: str):
1091
+ async with self._outstanding_flows_lock:
1092
+ del self._outstanding_flows[state]