aiqtoolkit 1.1.0a20250429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiqtoolkit might be problematic. Click here for more details.

Files changed (309) hide show
  1. aiq/agent/__init__.py +0 -0
  2. aiq/agent/base.py +76 -0
  3. aiq/agent/dual_node.py +67 -0
  4. aiq/agent/react_agent/__init__.py +0 -0
  5. aiq/agent/react_agent/agent.py +322 -0
  6. aiq/agent/react_agent/output_parser.py +104 -0
  7. aiq/agent/react_agent/prompt.py +46 -0
  8. aiq/agent/react_agent/register.py +148 -0
  9. aiq/agent/reasoning_agent/__init__.py +0 -0
  10. aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
  11. aiq/agent/register.py +23 -0
  12. aiq/agent/rewoo_agent/__init__.py +0 -0
  13. aiq/agent/rewoo_agent/agent.py +410 -0
  14. aiq/agent/rewoo_agent/prompt.py +108 -0
  15. aiq/agent/rewoo_agent/register.py +158 -0
  16. aiq/agent/tool_calling_agent/__init__.py +0 -0
  17. aiq/agent/tool_calling_agent/agent.py +123 -0
  18. aiq/agent/tool_calling_agent/register.py +105 -0
  19. aiq/builder/__init__.py +0 -0
  20. aiq/builder/builder.py +223 -0
  21. aiq/builder/component_utils.py +303 -0
  22. aiq/builder/context.py +198 -0
  23. aiq/builder/embedder.py +24 -0
  24. aiq/builder/eval_builder.py +116 -0
  25. aiq/builder/evaluator.py +29 -0
  26. aiq/builder/framework_enum.py +24 -0
  27. aiq/builder/front_end.py +73 -0
  28. aiq/builder/function.py +297 -0
  29. aiq/builder/function_base.py +372 -0
  30. aiq/builder/function_info.py +627 -0
  31. aiq/builder/intermediate_step_manager.py +125 -0
  32. aiq/builder/llm.py +25 -0
  33. aiq/builder/retriever.py +25 -0
  34. aiq/builder/user_interaction_manager.py +71 -0
  35. aiq/builder/workflow.py +134 -0
  36. aiq/builder/workflow_builder.py +733 -0
  37. aiq/cli/__init__.py +14 -0
  38. aiq/cli/cli_utils/__init__.py +0 -0
  39. aiq/cli/cli_utils/config_override.py +233 -0
  40. aiq/cli/cli_utils/validation.py +37 -0
  41. aiq/cli/commands/__init__.py +0 -0
  42. aiq/cli/commands/configure/__init__.py +0 -0
  43. aiq/cli/commands/configure/channel/__init__.py +0 -0
  44. aiq/cli/commands/configure/channel/add.py +28 -0
  45. aiq/cli/commands/configure/channel/channel.py +34 -0
  46. aiq/cli/commands/configure/channel/remove.py +30 -0
  47. aiq/cli/commands/configure/channel/update.py +30 -0
  48. aiq/cli/commands/configure/configure.py +33 -0
  49. aiq/cli/commands/evaluate.py +139 -0
  50. aiq/cli/commands/info/__init__.py +14 -0
  51. aiq/cli/commands/info/info.py +37 -0
  52. aiq/cli/commands/info/list_channels.py +32 -0
  53. aiq/cli/commands/info/list_components.py +129 -0
  54. aiq/cli/commands/registry/__init__.py +14 -0
  55. aiq/cli/commands/registry/publish.py +88 -0
  56. aiq/cli/commands/registry/pull.py +118 -0
  57. aiq/cli/commands/registry/registry.py +36 -0
  58. aiq/cli/commands/registry/remove.py +108 -0
  59. aiq/cli/commands/registry/search.py +155 -0
  60. aiq/cli/commands/start.py +250 -0
  61. aiq/cli/commands/uninstall.py +83 -0
  62. aiq/cli/commands/validate.py +47 -0
  63. aiq/cli/commands/workflow/__init__.py +14 -0
  64. aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  65. aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
  66. aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  67. aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
  68. aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  69. aiq/cli/commands/workflow/workflow.py +37 -0
  70. aiq/cli/commands/workflow/workflow_commands.py +307 -0
  71. aiq/cli/entrypoint.py +133 -0
  72. aiq/cli/main.py +44 -0
  73. aiq/cli/register_workflow.py +408 -0
  74. aiq/cli/type_registry.py +869 -0
  75. aiq/data_models/__init__.py +14 -0
  76. aiq/data_models/api_server.py +550 -0
  77. aiq/data_models/common.py +143 -0
  78. aiq/data_models/component.py +46 -0
  79. aiq/data_models/component_ref.py +135 -0
  80. aiq/data_models/config.py +349 -0
  81. aiq/data_models/dataset_handler.py +122 -0
  82. aiq/data_models/discovery_metadata.py +269 -0
  83. aiq/data_models/embedder.py +26 -0
  84. aiq/data_models/evaluate.py +101 -0
  85. aiq/data_models/evaluator.py +26 -0
  86. aiq/data_models/front_end.py +26 -0
  87. aiq/data_models/function.py +30 -0
  88. aiq/data_models/function_dependencies.py +64 -0
  89. aiq/data_models/interactive.py +237 -0
  90. aiq/data_models/intermediate_step.py +269 -0
  91. aiq/data_models/invocation_node.py +38 -0
  92. aiq/data_models/llm.py +26 -0
  93. aiq/data_models/logging.py +26 -0
  94. aiq/data_models/memory.py +26 -0
  95. aiq/data_models/profiler.py +53 -0
  96. aiq/data_models/registry_handler.py +26 -0
  97. aiq/data_models/retriever.py +30 -0
  98. aiq/data_models/step_adaptor.py +64 -0
  99. aiq/data_models/streaming.py +33 -0
  100. aiq/data_models/swe_bench_model.py +54 -0
  101. aiq/data_models/telemetry_exporter.py +26 -0
  102. aiq/embedder/__init__.py +0 -0
  103. aiq/embedder/langchain_client.py +41 -0
  104. aiq/embedder/nim_embedder.py +58 -0
  105. aiq/embedder/openai_embedder.py +42 -0
  106. aiq/embedder/register.py +24 -0
  107. aiq/eval/__init__.py +14 -0
  108. aiq/eval/config.py +42 -0
  109. aiq/eval/dataset_handler/__init__.py +0 -0
  110. aiq/eval/dataset_handler/dataset_downloader.py +106 -0
  111. aiq/eval/dataset_handler/dataset_filter.py +52 -0
  112. aiq/eval/dataset_handler/dataset_handler.py +164 -0
  113. aiq/eval/evaluate.py +322 -0
  114. aiq/eval/evaluator/__init__.py +14 -0
  115. aiq/eval/evaluator/evaluator_model.py +44 -0
  116. aiq/eval/intermediate_step_adapter.py +93 -0
  117. aiq/eval/rag_evaluator/__init__.py +0 -0
  118. aiq/eval/rag_evaluator/evaluate.py +138 -0
  119. aiq/eval/rag_evaluator/register.py +138 -0
  120. aiq/eval/register.py +22 -0
  121. aiq/eval/remote_workflow.py +128 -0
  122. aiq/eval/runtime_event_subscriber.py +52 -0
  123. aiq/eval/swe_bench_evaluator/__init__.py +0 -0
  124. aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
  125. aiq/eval/swe_bench_evaluator/register.py +36 -0
  126. aiq/eval/trajectory_evaluator/__init__.py +0 -0
  127. aiq/eval/trajectory_evaluator/evaluate.py +118 -0
  128. aiq/eval/trajectory_evaluator/register.py +40 -0
  129. aiq/eval/utils/__init__.py +0 -0
  130. aiq/eval/utils/output_uploader.py +131 -0
  131. aiq/eval/utils/tqdm_position_registry.py +40 -0
  132. aiq/front_ends/__init__.py +14 -0
  133. aiq/front_ends/console/__init__.py +14 -0
  134. aiq/front_ends/console/console_front_end_config.py +32 -0
  135. aiq/front_ends/console/console_front_end_plugin.py +107 -0
  136. aiq/front_ends/console/register.py +25 -0
  137. aiq/front_ends/cron/__init__.py +14 -0
  138. aiq/front_ends/fastapi/__init__.py +14 -0
  139. aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
  140. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
  141. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
  142. aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  143. aiq/front_ends/fastapi/job_store.py +161 -0
  144. aiq/front_ends/fastapi/main.py +70 -0
  145. aiq/front_ends/fastapi/message_handler.py +279 -0
  146. aiq/front_ends/fastapi/message_validator.py +345 -0
  147. aiq/front_ends/fastapi/register.py +25 -0
  148. aiq/front_ends/fastapi/response_helpers.py +181 -0
  149. aiq/front_ends/fastapi/step_adaptor.py +315 -0
  150. aiq/front_ends/fastapi/websocket.py +148 -0
  151. aiq/front_ends/mcp/__init__.py +14 -0
  152. aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
  153. aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
  154. aiq/front_ends/mcp/register.py +27 -0
  155. aiq/front_ends/mcp/tool_converter.py +242 -0
  156. aiq/front_ends/register.py +22 -0
  157. aiq/front_ends/simple_base/__init__.py +14 -0
  158. aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
  159. aiq/llm/__init__.py +0 -0
  160. aiq/llm/nim_llm.py +45 -0
  161. aiq/llm/openai_llm.py +45 -0
  162. aiq/llm/register.py +22 -0
  163. aiq/llm/utils/__init__.py +14 -0
  164. aiq/llm/utils/env_config_value.py +94 -0
  165. aiq/llm/utils/error.py +17 -0
  166. aiq/memory/__init__.py +20 -0
  167. aiq/memory/interfaces.py +183 -0
  168. aiq/memory/models.py +102 -0
  169. aiq/meta/module_to_distro.json +3 -0
  170. aiq/meta/pypi.md +59 -0
  171. aiq/observability/__init__.py +0 -0
  172. aiq/observability/async_otel_listener.py +270 -0
  173. aiq/observability/register.py +97 -0
  174. aiq/plugins/.namespace +1 -0
  175. aiq/profiler/__init__.py +0 -0
  176. aiq/profiler/callbacks/__init__.py +0 -0
  177. aiq/profiler/callbacks/agno_callback_handler.py +295 -0
  178. aiq/profiler/callbacks/base_callback_class.py +20 -0
  179. aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
  180. aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
  181. aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  182. aiq/profiler/callbacks/token_usage_base_model.py +27 -0
  183. aiq/profiler/data_frame_row.py +51 -0
  184. aiq/profiler/decorators/__init__.py +0 -0
  185. aiq/profiler/decorators/framework_wrapper.py +131 -0
  186. aiq/profiler/decorators/function_tracking.py +254 -0
  187. aiq/profiler/forecasting/__init__.py +0 -0
  188. aiq/profiler/forecasting/config.py +18 -0
  189. aiq/profiler/forecasting/model_trainer.py +75 -0
  190. aiq/profiler/forecasting/models/__init__.py +22 -0
  191. aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
  192. aiq/profiler/forecasting/models/linear_model.py +196 -0
  193. aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
  194. aiq/profiler/inference_metrics_model.py +25 -0
  195. aiq/profiler/inference_optimization/__init__.py +0 -0
  196. aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  197. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
  198. aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  199. aiq/profiler/inference_optimization/data_models.py +386 -0
  200. aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
  201. aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  202. aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  203. aiq/profiler/inference_optimization/llm_metrics.py +212 -0
  204. aiq/profiler/inference_optimization/prompt_caching.py +163 -0
  205. aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
  206. aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
  207. aiq/profiler/intermediate_property_adapter.py +102 -0
  208. aiq/profiler/profile_runner.py +433 -0
  209. aiq/profiler/utils.py +184 -0
  210. aiq/registry_handlers/__init__.py +0 -0
  211. aiq/registry_handlers/local/__init__.py +0 -0
  212. aiq/registry_handlers/local/local_handler.py +176 -0
  213. aiq/registry_handlers/local/register_local.py +37 -0
  214. aiq/registry_handlers/metadata_factory.py +60 -0
  215. aiq/registry_handlers/package_utils.py +198 -0
  216. aiq/registry_handlers/pypi/__init__.py +0 -0
  217. aiq/registry_handlers/pypi/pypi_handler.py +251 -0
  218. aiq/registry_handlers/pypi/register_pypi.py +40 -0
  219. aiq/registry_handlers/register.py +21 -0
  220. aiq/registry_handlers/registry_handler_base.py +157 -0
  221. aiq/registry_handlers/rest/__init__.py +0 -0
  222. aiq/registry_handlers/rest/register_rest.py +56 -0
  223. aiq/registry_handlers/rest/rest_handler.py +237 -0
  224. aiq/registry_handlers/schemas/__init__.py +0 -0
  225. aiq/registry_handlers/schemas/headers.py +42 -0
  226. aiq/registry_handlers/schemas/package.py +68 -0
  227. aiq/registry_handlers/schemas/publish.py +63 -0
  228. aiq/registry_handlers/schemas/pull.py +81 -0
  229. aiq/registry_handlers/schemas/remove.py +36 -0
  230. aiq/registry_handlers/schemas/search.py +91 -0
  231. aiq/registry_handlers/schemas/status.py +47 -0
  232. aiq/retriever/__init__.py +0 -0
  233. aiq/retriever/interface.py +37 -0
  234. aiq/retriever/milvus/__init__.py +14 -0
  235. aiq/retriever/milvus/register.py +81 -0
  236. aiq/retriever/milvus/retriever.py +228 -0
  237. aiq/retriever/models.py +74 -0
  238. aiq/retriever/nemo_retriever/__init__.py +14 -0
  239. aiq/retriever/nemo_retriever/register.py +60 -0
  240. aiq/retriever/nemo_retriever/retriever.py +190 -0
  241. aiq/retriever/register.py +22 -0
  242. aiq/runtime/__init__.py +14 -0
  243. aiq/runtime/loader.py +188 -0
  244. aiq/runtime/runner.py +176 -0
  245. aiq/runtime/session.py +116 -0
  246. aiq/settings/__init__.py +0 -0
  247. aiq/settings/global_settings.py +318 -0
  248. aiq/test/.namespace +1 -0
  249. aiq/tool/__init__.py +0 -0
  250. aiq/tool/code_execution/__init__.py +0 -0
  251. aiq/tool/code_execution/code_sandbox.py +188 -0
  252. aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  253. aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
  254. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
  255. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
  256. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
  257. aiq/tool/code_execution/register.py +70 -0
  258. aiq/tool/code_execution/utils.py +100 -0
  259. aiq/tool/datetime_tools.py +42 -0
  260. aiq/tool/document_search.py +141 -0
  261. aiq/tool/github_tools/__init__.py +0 -0
  262. aiq/tool/github_tools/create_github_commit.py +133 -0
  263. aiq/tool/github_tools/create_github_issue.py +87 -0
  264. aiq/tool/github_tools/create_github_pr.py +106 -0
  265. aiq/tool/github_tools/get_github_file.py +106 -0
  266. aiq/tool/github_tools/get_github_issue.py +166 -0
  267. aiq/tool/github_tools/get_github_pr.py +256 -0
  268. aiq/tool/github_tools/update_github_issue.py +100 -0
  269. aiq/tool/mcp/__init__.py +14 -0
  270. aiq/tool/mcp/mcp_client.py +220 -0
  271. aiq/tool/mcp/mcp_tool.py +75 -0
  272. aiq/tool/memory_tools/__init__.py +0 -0
  273. aiq/tool/memory_tools/add_memory_tool.py +67 -0
  274. aiq/tool/memory_tools/delete_memory_tool.py +67 -0
  275. aiq/tool/memory_tools/get_memory_tool.py +72 -0
  276. aiq/tool/nvidia_rag.py +95 -0
  277. aiq/tool/register.py +36 -0
  278. aiq/tool/retriever.py +89 -0
  279. aiq/utils/__init__.py +0 -0
  280. aiq/utils/data_models/__init__.py +0 -0
  281. aiq/utils/data_models/schema_validator.py +58 -0
  282. aiq/utils/debugging_utils.py +43 -0
  283. aiq/utils/exception_handlers/__init__.py +0 -0
  284. aiq/utils/exception_handlers/schemas.py +114 -0
  285. aiq/utils/io/__init__.py +0 -0
  286. aiq/utils/io/yaml_tools.py +50 -0
  287. aiq/utils/metadata_utils.py +74 -0
  288. aiq/utils/producer_consumer_queue.py +178 -0
  289. aiq/utils/reactive/__init__.py +0 -0
  290. aiq/utils/reactive/base/__init__.py +0 -0
  291. aiq/utils/reactive/base/observable_base.py +65 -0
  292. aiq/utils/reactive/base/observer_base.py +55 -0
  293. aiq/utils/reactive/base/subject_base.py +79 -0
  294. aiq/utils/reactive/observable.py +59 -0
  295. aiq/utils/reactive/observer.py +76 -0
  296. aiq/utils/reactive/subject.py +131 -0
  297. aiq/utils/reactive/subscription.py +49 -0
  298. aiq/utils/settings/__init__.py +0 -0
  299. aiq/utils/settings/global_settings.py +197 -0
  300. aiq/utils/type_converter.py +232 -0
  301. aiq/utils/type_utils.py +397 -0
  302. aiq/utils/url_utils.py +27 -0
  303. aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
  304. aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
  305. aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
  306. aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
  307. aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
  308. aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
  309. aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
aiq/runtime/session.py ADDED
@@ -0,0 +1,116 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import contextvars
18
+ import typing
19
+ from collections.abc import Awaitable
20
+ from collections.abc import Callable
21
+ from contextlib import asynccontextmanager
22
+ from contextlib import nullcontext
23
+
24
+ from aiq.builder.context import AIQContext
25
+ from aiq.builder.context import AIQContextState
26
+ from aiq.builder.workflow import Workflow
27
+ from aiq.data_models.config import AIQConfig
28
+ from aiq.data_models.interactive import HumanResponse
29
+ from aiq.data_models.interactive import InteractionPrompt
30
+
31
+ _T = typing.TypeVar("_T")
32
+
33
+
34
class UserManagerBase:
    """
    Placeholder base type for user managers.

    It defines no behavior yet. It is presumably intended as the base for objects passed as ``user_manager``
    to ``AIQSessionManager.session``; confirm against callers.
    """
36
+
37
+
38
class AIQSessionManager:
    """Run and manage user workflow sessions, limiting how many workflow runs execute at once."""

    def __init__(self, workflow: Workflow, max_concurrency: int = 8):
        """
        Create a session manager for ``workflow``.

        The manager owns the workflow's context state and runs the workflow with the configured concurrency limit.

        Parameters
        ----------
        workflow : Workflow
            The workflow to run.
        max_concurrency : int, optional
            The maximum number of simultaneous workflow invocations, by default 8. Any value <= 0 disables the
            limit.

        Raises
        ------
        ValueError
            If ``workflow`` is None.
        """
        if workflow is None:
            raise ValueError("Workflow cannot be None")

        self._workflow: Workflow = workflow
        self._max_concurrency = max_concurrency

        self._context_state = AIQContextState.get()
        self._context = AIQContext(self._context_state)

        # Snapshot every context variable visible right now. ``run`` re-applies this snapshot, because a run may
        # execute in a different context than the one that built this manager (the original note cites Uvicorn
        # request handling).
        self._saved_context = contextvars.copy_context()

        # A non-positive limit means "unbounded". ``nullcontext`` still supports ``async with``, so ``run`` does not
        # need to branch.
        self._semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency > 0 else nullcontext()

    @property
    def config(self) -> AIQConfig:
        """The configuration of the managed workflow."""
        return self._workflow.config

    @property
    def workflow(self) -> Workflow:
        """The managed workflow."""
        return self._workflow

    @property
    def context(self) -> AIQContext:
        """The ``AIQContext`` wrapping this manager's context state."""
        return self._context

    @asynccontextmanager
    async def session(self,
                      user_manager=None,
                      user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None):
        """
        Bind per-session values into the context state for the duration of the ``async with`` block.

        Parameters
        ----------
        user_manager : optional
            Stored in the context state's ``user_manager`` entry when not None.
        user_input_callback : optional
            Callback used to answer interaction prompts. Stored in ``user_input_callback`` when not None.

        Yields
        ------
        AIQSessionManager
            This manager.
        """
        # (context-state entry, reset token) pairs, undone in reverse order of binding on exit.
        bindings = []

        if user_input_callback is not None:
            callback_var = self._context_state.user_input_callback
            bindings.append((callback_var, callback_var.set(user_input_callback)))

        if user_manager is not None:
            manager_var = self._context_state.user_manager
            bindings.append((manager_var, manager_var.set(user_manager)))

        try:
            yield self
        finally:
            for bound_var, reset_token in reversed(bindings):
                bound_var.reset(reset_token)

    @asynccontextmanager
    async def run(self, message):
        """
        Start a workflow run.

        A concurrency slot is held while the run is active. The context variables captured at construction time
        are re-applied before the workflow starts.

        Parameters
        ----------
        message
            The input passed through to ``Workflow.run``.

        Yields
        ------
        The runner produced by ``Workflow.run``.
        """
        async with self._semaphore:
            snapshot = self._saved_context
            for context_var, saved_value in snapshot.items():
                context_var.set(saved_value)

            async with self._workflow.run(message) as workflow_runner:
                yield workflow_runner
File without changes
@@ -0,0 +1,318 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import logging
18
+ import os
19
+ import typing
20
+ from collections.abc import Callable
21
+ from contextlib import contextmanager
22
+ from copy import deepcopy
23
+
24
+ from platformdirs import user_config_dir
25
+ from pydantic import ConfigDict
26
+ from pydantic import Discriminator
27
+ from pydantic import Tag
28
+ from pydantic import ValidationError
29
+ from pydantic import ValidationInfo
30
+ from pydantic import ValidatorFunctionWrapHandler
31
+ from pydantic import field_validator
32
+
33
+ from aiq.cli.type_registry import GlobalTypeRegistry
34
+ from aiq.cli.type_registry import RegisteredInfo
35
+ from aiq.data_models.common import HashableBaseModel
36
+ from aiq.data_models.common import TypedBaseModel
37
+ from aiq.data_models.common import TypedBaseModelT
38
+ from aiq.data_models.registry_handler import RegistryHandlerBaseConfig
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ class Settings(HashableBaseModel):
44
+
45
+ model_config = ConfigDict(extra="forbid")
46
+
47
+ # Registry Handeler Configuration
48
+ channels: dict[str, RegistryHandlerBaseConfig] = {}
49
+
50
+ _configuration_directory: typing.ClassVar[str]
51
+ _settings_changed_hooks: typing.ClassVar[list[Callable[[], None]]] = []
52
+ _settings_changed_hooks_active: bool = True
53
+
54
+ @field_validator("channels", mode="wrap")
55
+ @classmethod
56
+ def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
57
+
58
+ try:
59
+ return handler(value)
60
+ except ValidationError as err:
61
+
62
+ for e in err.errors():
63
+ if e['type'] == 'union_tag_invalid' and len(e['loc']) > 0:
64
+ requested_type = e['loc'][0]
65
+
66
+ if (info.field_name == "channels"):
67
+ registered_keys = GlobalTypeRegistry.get().get_registered_registry_handlers()
68
+ else:
69
+ assert False, f"Unknown field name {info.field_name} in validator"
70
+
71
+ # Check and see if the there are multiple full types which match this short type
72
+ matching_keys = [k for k in registered_keys if k.local_name == requested_type]
73
+
74
+ assert len(matching_keys) != 1, "Exact match should have been found. Contact developers"
75
+
76
+ matching_key_names = [x.full_type for x in matching_keys]
77
+ registered_key_names = [x.full_type for x in registered_keys]
78
+
79
+ if (len(matching_keys) == 0):
80
+ # This is a case where the requested type is not found. Show a helpful message about what is
81
+ # available
82
+ raise ValueError(
83
+ f"Requested {info.field_name} type `{requested_type}` not found. "
84
+ "Have you ensured the necessary package has been installed with `uv pip install`?"
85
+ "\nAvailable {} names:\n - {}".format(info.field_name,
86
+ '\n - '.join(registered_key_names))) from err
87
+
88
+ # This is a case where the requested type is ambiguous.
89
+ raise ValueError(f"Requested {info.field_name} type `{requested_type}` is ambiguous. " +
90
+ f"Matched multiple {info.field_name} by their local name: {matching_key_names}. " +
91
+ f"Please use the fully qualified {info.field_name} name." +
92
+ "\nAvailable {} names:\n - {}".format(info.field_name,
93
+ '\n - '.join(registered_key_names))) from err
94
+
95
+ raise
96
+
97
+ @classmethod
98
+ def rebuild_annotations(cls):
99
+
100
+ def compute_annotation(cls: type[TypedBaseModelT], registrations: list[RegisteredInfo[TypedBaseModelT]]):
101
+
102
+ while (len(registrations) < 2):
103
+ registrations.append(RegisteredInfo[TypedBaseModelT](full_type=f"_ignore/{len(registrations)}",
104
+ config_type=cls))
105
+
106
+ short_names: dict[str, int] = {}
107
+ type_list: list[tuple[str, type[TypedBaseModelT]]] = []
108
+
109
+ # For all keys in the list, split the key by / and increment the count of the last element
110
+ for key in registrations:
111
+ short_names[key.local_name] = short_names.get(key.local_name, 0) + 1
112
+
113
+ type_list.append((key.full_type, key.config_type))
114
+
115
+ # Now loop again and if the short name is unique, then create two entries, for the short and full name
116
+ for key in registrations:
117
+
118
+ if (short_names[key.local_name] == 1):
119
+ type_list.append((key.local_name, key.config_type))
120
+
121
+ # pylint: disable=consider-alternative-union-syntax
122
+ return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
123
+
124
+ RegistryHandlerAnnotation = dict[
125
+ str,
126
+ typing.Annotated[compute_annotation(RegistryHandlerBaseConfig,
127
+ GlobalTypeRegistry.get().get_registered_registry_handlers()),
128
+ Discriminator(TypedBaseModel.discriminator)]]
129
+
130
+ should_rebuild = False
131
+
132
+ channels_field = cls.model_fields.get("channels")
133
+ if channels_field is not None and channels_field.annotation != RegistryHandlerAnnotation:
134
+ channels_field.annotation = RegistryHandlerAnnotation
135
+ should_rebuild = True
136
+
137
+ if (should_rebuild):
138
+ cls.model_rebuild(force=True)
139
+
140
+ @property
141
+ def channel_names(self) -> list:
142
+ return list(self.channels.keys())
143
+
144
+ @property
145
+ def configuration_directory(self) -> str:
146
+ return self._configuration_directory
147
+
148
+ @property
149
+ def configuration_file(self) -> str:
150
+ return os.path.join(self.configuration_directory, "config.json")
151
+
152
+ @staticmethod
153
+ def from_file():
154
+
155
+ configuration_directory = os.getenv("AIQ_CONFIG_DIR", user_config_dir(appname="aiq"))
156
+
157
+ if not os.path.exists(configuration_directory):
158
+ os.makedirs(configuration_directory, exist_ok=True)
159
+
160
+ configuration_file = os.path.join(configuration_directory, "config.json")
161
+
162
+ file_path = os.path.join(configuration_directory, "config.json")
163
+
164
+ if (not os.path.exists(configuration_file)):
165
+ loaded_config = {}
166
+ else:
167
+ with open(file_path, mode="r", encoding="utf-8") as f:
168
+ loaded_config = json.load(f)
169
+
170
+ settings = Settings(**loaded_config)
171
+ settings.set_configuration_directory(configuration_directory)
172
+ return settings
173
+
174
+ def set_configuration_directory(self, directory: str, remove: bool = False) -> None:
175
+ if (remove):
176
+ if os.path.exists(self.configuration_directory):
177
+ os.rmdir(self.configuration_directory)
178
+ self.__class__._configuration_directory = directory
179
+
180
+ def reset_configuration_directory(self, remove: bool = False) -> None:
181
+ if (remove):
182
+ if os.path.exists(self.configuration_directory):
183
+ os.rmdir(self.configuration_directory)
184
+ self._configuration_directory = os.getenv("AIQ_CONFIG_DIR", user_config_dir(appname="aiq"))
185
+
186
+ def _save_settings(self) -> None:
187
+
188
+ if not os.path.exists(self.configuration_directory):
189
+ os.mkdir(self.configuration_directory)
190
+
191
+ with open(self.configuration_file, mode="w", encoding="utf-8") as f:
192
+ f.write(self.model_dump_json(indent=4, by_alias=True, serialize_as_any=True))
193
+
194
+ self._settings_changed()
195
+
196
+ def update_settings(self, config_obj: "dict | Settings"):
197
+ self._update_settings(config_obj)
198
+
199
+ def _update_settings(self, config_obj: "dict | Settings"):
200
+
201
+ if isinstance(config_obj, Settings):
202
+ config_obj = config_obj.model_dump(serialize_as_any=True, by_alias=True)
203
+
204
+ self._revalidate(config_dict=config_obj)
205
+
206
+ self._save_settings()
207
+
208
+ def _revalidate(self, config_dict) -> bool:
209
+
210
+ try:
211
+ validated_data = self.__class__(**config_dict)
212
+
213
+ for field in validated_data.model_fields:
214
+ match field:
215
+ case "channels":
216
+ self.channels = validated_data.channels
217
+ case _:
218
+ raise ValueError(f"Encountered invalid model field: {field}")
219
+
220
+ return True
221
+
222
+ except Exception as e:
223
+ logger.exception("Unable to validate user settings configuration: %s", e, exc_info=True)
224
+ return False
225
+
226
+ def print_channel_settings(self, channel_type: str | None = None) -> None:
227
+
228
+ import yaml
229
+
230
+ remote_channels = self.model_dump(serialize_as_any=True, by_alias=True)
231
+
232
+ if (not remote_channels or not remote_channels.get("channels")):
233
+ logger.warning("No configured channels to list.")
234
+ return
235
+
236
+ if (channel_type is not None):
237
+ filter_channels = []
238
+ for channel, settings in remote_channels.items():
239
+ if (settings["type"] != channel_type):
240
+ filter_channels.append(channel)
241
+ for channel in filter_channels:
242
+ del remote_channels[channel]
243
+
244
+ if (remote_channels):
245
+ logger.info(yaml.dump(remote_channels, allow_unicode=True, default_flow_style=False))
246
+
247
+ def override_settings(self, config_file: str) -> "Settings":
248
+
249
+ from aiq.utils.io.yaml_tools import yaml_load
250
+
251
+ override_settings_dict = yaml_load(config_file)
252
+
253
+ settings_dict = self.model_dump()
254
+ updated_settings = {**override_settings_dict, **settings_dict}
255
+ self._update_settings(config_obj=updated_settings)
256
+
257
+ return self
258
+
259
+ def _settings_changed(self):
260
+
261
+ if (not self._settings_changed_hooks_active):
262
+ return
263
+
264
+ for hook in self._settings_changed_hooks:
265
+ hook()
266
+
267
+ @contextmanager
268
+ def pause_settings_changed_hooks(self):
269
+
270
+ self._settings_changed_hooks_active = False
271
+
272
+ try:
273
+ yield
274
+ finally:
275
+ self._settings_changed_hooks_active = True
276
+
277
+ # Ensure that the registration changed hooks are called
278
+ self._settings_changed()
279
+
280
+ def add_settings_changed_hook(self, cb: Callable[[], None]) -> None:
281
+
282
+ self._settings_changed_hooks.append(cb)
283
+
284
+
285
def _rebuild_settings_annotations() -> None:
    """Refresh the ``Settings.channels`` annotation whenever the global type registry changes."""
    Settings.rebuild_annotations()


GlobalTypeRegistry.get().add_registration_changed_hook(_rebuild_settings_annotations)
286
+
287
+
288
class GlobalSettings:
    """Process-wide holder for the lazily loaded ``Settings`` instance."""

    _global_settings: Settings | None = None

    @staticmethod
    def get() -> Settings:
        """
        Return the global settings, loading them on first use.

        The first call discovers and registers registry-handler plugins, then reads the settings from file.
        """
        if GlobalSettings._global_settings is not None:
            return GlobalSettings._global_settings

        from aiq.runtime.loader import PluginTypes
        from aiq.runtime.loader import discover_and_register_plugins

        discover_and_register_plugins(PluginTypes.REGISTRY_HANDLER)

        GlobalSettings._global_settings = Settings.from_file()
        return GlobalSettings._global_settings

    @staticmethod
    @contextmanager
    def push():
        """
        Temporarily replace the global settings with a deep copy and yield that copy.

        On exit, the original settings are restored and their change hooks are fired.
        """
        original = GlobalSettings.get()
        scratch_copy = deepcopy(original)

        try:
            GlobalSettings._global_settings = scratch_copy
            yield scratch_copy
        finally:
            GlobalSettings._global_settings = original
            original._settings_changed()
aiq/test/.namespace ADDED
@@ -0,0 +1 @@
1
+ Note: This is a python namespace package and this directory should remain empty. Do NOT add a `__init__.py` file or any other files to this directory. This file is also needed to ensure the directory exists in git.
aiq/tool/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,188 @@
1
+ # Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import json
17
+ import logging
18
+ from urllib.parse import urljoin
19
+
20
+ import requests
21
+ from pydantic import HttpUrl
22
+
23
# Module logger named after the dotted module path (not the file path), so it
# participates in the package's logger hierarchy and configuration.
logger = logging.getLogger(__name__)
24
+
25
+
26
class Sandbox(abc.ABC):
    """Client for a remote code-execution sandbox reached over HTTP.

    Subclasses adapt the client to a concrete sandbox server by implementing
    ``_get_execute_url``, ``_prepare_request`` and ``_parse_request_output``.

    Args:
        uri: Base URI of the sandbox server (keyword-only). ``_get_execute_url``
            derives the execute endpoint from it.
    """

    def __init__(
        self,
        *,
        uri: HttpUrl,
    ):
        self.url = self._get_execute_url(uri)
        session = requests.Session()
        # One large shared connection pool so many concurrent executions can reuse
        # connections. max_retries only retries connection-level failures
        # (DNS/connect/connect-timeout), never HTTP error statuses.
        adapter = requests.adapters.HTTPAdapter(pool_maxsize=1500, pool_connections=1500, max_retries=3)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        self.http_session = session

    def _send_request(self, request, timeout):
        """POST ``request`` as JSON to the execute endpoint and return the parsed response.

        This call blocks the calling thread.

        Raises:
            requests.exceptions.Timeout: If the HTTP request times out or the
                server answers with status 502.
        """
        output = self.http_session.post(
            url=self.url,
            data=json.dumps(request),
            timeout=timeout,
            headers={"Content-Type": "application/json"},
        )
        # A 502 (bad gateway) is surfaced to the caller as a timeout; it is not retried here.
        if output.status_code == 502:
            raise requests.exceptions.Timeout

        return self._parse_request_output(output)

    @abc.abstractmethod
    def _parse_request_output(self, output):
        """Convert the HTTP response into the result dict returned by ``execute_code``."""

    @abc.abstractmethod
    def _get_execute_url(self, uri):
        """Return the URL that execution requests are POSTed to."""

    @abc.abstractmethod
    def _prepare_request(self, generated_code, timeout):
        """Build the JSON-serialisable request payload for ``generated_code``."""

    async def execute_code(
        self,
        generated_code: str,
        timeout: float = 10.0,
        language: str = "python",
        max_output_characters: int = 1000,
    ) -> dict:
        """Execute ``generated_code`` in the sandbox and return the result dict.

        The code is wrapped in a Python driver program with the following behaviour:
        it runs the code with ``exec``, captures stdout and stderr, truncates each
        to ``max_output_characters`` characters (appending ``<output cut>``), and
        prints a JSON object with the keys ``process_status`` (``"completed"`` or
        ``"error"``), ``stdout`` and ``stderr``. If the request times out, or the
        server answers 502, ``{"process_status": "timeout", ...}`` is returned instead.

        Args:
            generated_code: Python source. Surrounding whitespace and backticks are stripped.
            timeout: Execution timeout sent to the sandbox. It is also used as the HTTP timeout.
            language: Currently unused. The driver is always Python and this value
                is not forwarded to ``_prepare_request``.
            max_output_characters: Maximum characters kept from each of stdout and stderr.

        Returns:
            The dict produced by ``_parse_request_output``, or the timeout dict.
        """
        import asyncio

        generated_code = generated_code.strip().strip("`")
        code_to_execute = """
import traceback
import json
import os
import warnings
import contextlib
import io
warnings.filterwarnings('ignore')
os.environ['OPENBLAS_NUM_THREADS'] = '16'
"""

        code_to_execute += f"""
\ngenerated_code = {repr(generated_code)}\n
stdout = io.StringIO()
stderr = io.StringIO()

with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
    try:
        exec(generated_code)
        status = "completed"
    except Exception:
        status = "error"
        stderr.write(traceback.format_exc())
stdout = stdout.getvalue()
stderr = stderr.getvalue()
if len(stdout) > {max_output_characters}:
    stdout = stdout[:{max_output_characters}] + "<output cut>"
if len(stderr) > {max_output_characters}:
    stderr = stderr[:{max_output_characters}] + "<output cut>"
if stdout:
    stdout += "\\n"
if stderr:
    stderr += "\\n"
output = {{"process_status": status, "stdout": stdout, "stderr": stderr}}
print(json.dumps(output))
"""
        request = self._prepare_request(code_to_execute, timeout)
        try:
            # requests is synchronous: run the HTTP call in a worker thread so this
            # coroutine does not block the event loop while the sandbox executes.
            output = await asyncio.to_thread(self._send_request, request, timeout)
        except requests.exceptions.Timeout:
            output = {"process_status": "timeout", "stdout": "", "stderr": "Timed out\n"}
        return output
129
+
130
+
131
class LocalSandbox(Sandbox):
    """Locally hosted sandbox server that accepts JSON requests at ``<uri>/execute``."""

    def _get_execute_url(self, uri):
        # Resolved with urljoin: a base URI without a trailing slash has its last
        # path segment replaced by "execute".
        return urljoin(str(uri), "execute")

    def _parse_request_output(self, output):
        """Return the server's JSON body, or an error dict if the body is not valid JSON."""
        try:
            return output.json()
        except ValueError as e:
            # Depending on the requests version and whether simplejson is installed,
            # Response.json() raises json.JSONDecodeError, simplejson.JSONDecodeError or
            # requests.exceptions.JSONDecodeError. All of them subclass ValueError, but
            # not all subclass json.JSONDecodeError.
            logger.exception("Error parsing output: %s. %s", output.text, e)
            return {'process_status': 'error', 'stdout': '', 'stderr': 'Unknown error'}

    def _prepare_request(self, generated_code, timeout, language='python', **kwargs):
        """Build the request payload; extra keyword arguments are accepted and ignored."""
        return {
            "generated_code": generated_code,
            "timeout": timeout,
            "language": language,
        }
150
+
151
+
152
class PistonSandbox(Sandbox):
    """Piston sandbox (https://github.com/engineer-man/piston)"""

    def _get_execute_url(self, uri):
        return urljoin(str(uri), "execute")

    def _parse_request_output(self, output):
        """Decode a Piston response into a ``process_status``/``stdout``/``stderr`` dict.

        The driver program built by ``execute_code`` prints that dict as JSON, so the
        program's ``run.output`` is decoded as JSON. A SIGKILLed run, or output that is
        not valid JSON, is reported as an error using the same keys that
        ``LocalSandbox`` and the timeout path of ``execute_code`` use.
        """
        output = output.json()
        run = output['run']
        if run['signal'] == "SIGKILL":
            return {'process_status': 'error', 'stdout': '', 'stderr': 'Unknown error: SIGKILL'}
        try:
            return json.loads(run['output'])
        except ValueError:
            logger.exception("Error parsing output: %s", run['output'])
            return {'process_status': 'error', 'stdout': '', 'stderr': 'Unknown error'}

    def _prepare_request(self, generated_code: str, timeout, **kwargs):
        """Build a Piston execute payload; extra keyword arguments are ignored."""
        return {
            "language": "py",
            "version": "3.10.0",
            "files": [{
                "content": generated_code,
            }],
            "stdin": "",
            "args": [],
            "run_timeout": timeout * 1000.0,  # milliseconds
            # NOTE(review): -1 is presumably "no limit" in the Piston API; confirm.
            "compile_memory_limit": -1,
            "run_memory_limit": -1,
        }
177
+
178
+
179
# Lookup table from lower-case sandbox type name to the class that implements it.
sandboxes = dict(
    local=LocalSandbox,
    piston=PistonSandbox,
)
183
+
184
+
185
def get_sandbox(sandbox_type: str = "local", **kwargs):
    """A helper function to make it easier to set sandbox through cmd.

    Args:
        sandbox_type: Case-insensitive key into ``sandboxes``, e.g. ``"local"`` or ``"piston"``.
        **kwargs: Forwarded unchanged to the selected sandbox class's constructor.

    Returns:
        A new instance of the selected sandbox class.

    Raises:
        KeyError: If ``sandbox_type`` is not a registered sandbox type. The message
            lists the valid choices.
    """
    try:
        sandbox_class = sandboxes[sandbox_type.lower()]
    except KeyError:
        raise KeyError(f"Unknown sandbox type {sandbox_type!r}; expected one of {sorted(sandboxes)}") from None
    return sandbox_class(**kwargs)
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
# Use the base image with Python 3.10 and Flask
FROM tiangolo/uwsgi-nginx-flask:python3.10

# Install dependencies required for Lean 4 and other tools.
# - The apt package lists are removed in the same layer so they do not bloat the image.
# - elan is installed with Lean v4.12.0 as its default toolchain. By default elan-init
#   would also download the latest `stable` toolchain, which this image never uses.
# NOTE(review): elan-init.sh is fetched from the elan master branch; pin a release
#   tag or commit for reproducible builds.
RUN apt-get update && \
    apt-get install -y curl git && \
    rm -rf /var/lib/apt/lists/* && \
    curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y --default-toolchain leanprover/lean4:v4.12.0 && \
    /root/.elan/bin/elan toolchain install leanprover/lean4:v4.12.0 && \
    /root/.elan/bin/elan default leanprover/lean4:v4.12.0 && \
    /root/.elan/bin/elan self update

# Set environment variables to include Lean and elan/lake in the PATH
ENV PATH="/root/.elan/bin:$PATH"

# Create Lean project directory and initialize a new Lean project with Mathlib4
# NOTE(review): assumes `lake new` generates a lakefile.lean (not lakefile.toml) for this
#   toolchain, since the mathlib requirement is appended to lakefile.lean; confirm.
RUN mkdir -p /lean4 && cd /lean4 && \
    /root/.elan/bin/lake new my_project && \
    cd my_project && \
    echo 'leanprover/lean4:v4.12.0' > lean-toolchain && \
    echo 'require mathlib from git "https://github.com/leanprover-community/mathlib4" @ "v4.12.0"' >> lakefile.lean

# Download and cache Mathlib4 to avoid recompiling, then build the project
RUN cd /lean4/my_project && \
    /root/.elan/bin/lake exe cache get && \
    /root/.elan/bin/lake build

# Set environment variables to include Lean project path
ENV LEAN_PATH="/lean4/my_project"
ENV PATH="/lean4/my_project:$PATH"

# Set up application code and install Python dependencies
COPY sandbox.requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY local_sandbox_server.py /app/main.py

# Set the working directory to /app
WORKDIR /app

# Set Flask app environment variables and ports
ARG UWSGI_CHEAPER
ENV UWSGI_CHEAPER=$UWSGI_CHEAPER

ARG UWSGI_PROCESSES
ENV UWSGI_PROCESSES=$UWSGI_PROCESSES

ENV LISTEN_PORT=6000