aiqtoolkit 1.1.0a20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +76 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +322 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +46 -0
- aiq/agent/react_agent/register.py +148 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +410 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +123 -0
- aiq/agent/tool_calling_agent/register.py +105 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +223 -0
- aiq/builder/component_utils.py +303 -0
- aiq/builder/context.py +198 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +116 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +297 -0
- aiq/builder/function_base.py +372 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +125 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +71 -0
- aiq/builder/workflow.py +134 -0
- aiq/builder/workflow_builder.py +733 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +233 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +34 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +37 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +36 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/start.py +250 -0
- aiq/cli/commands/uninstall.py +83 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +307 -0
- aiq/cli/entrypoint.py +133 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +408 -0
- aiq/cli/type_registry.py +869 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +550 -0
- aiq/data_models/common.py +143 -0
- aiq/data_models/component.py +46 -0
- aiq/data_models/component_ref.py +135 -0
- aiq/data_models/config.py +349 -0
- aiq/data_models/dataset_handler.py +122 -0
- aiq/data_models/discovery_metadata.py +269 -0
- aiq/data_models/embedder.py +26 -0
- aiq/data_models/evaluate.py +101 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +64 -0
- aiq/data_models/interactive.py +237 -0
- aiq/data_models/intermediate_step.py +269 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +26 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +26 -0
- aiq/data_models/profiler.py +53 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +58 -0
- aiq/embedder/openai_embedder.py +42 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +42 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +164 -0
- aiq/eval/evaluate.py +322 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/evaluator_model.py +44 -0
- aiq/eval/intermediate_step_adapter.py +93 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +138 -0
- aiq/eval/rag_evaluator/register.py +138 -0
- aiq/eval/register.py +22 -0
- aiq/eval/remote_workflow.py +128 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +118 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +131 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +107 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +161 -0
- aiq/front_ends/fastapi/main.py +70 -0
- aiq/front_ends/fastapi/message_handler.py +279 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +181 -0
- aiq/front_ends/fastapi/step_adaptor.py +315 -0
- aiq/front_ends/fastapi/websocket.py +148 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/nim_llm.py +45 -0
- aiq/llm/openai_llm.py +45 -0
- aiq/llm/register.py +22 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +102 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +59 -0
- aiq/observability/__init__.py +0 -0
- aiq/observability/async_otel_listener.py +270 -0
- aiq/observability/register.py +97 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +25 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +433 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +198 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +81 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +188 -0
- aiq/runtime/runner.py +176 -0
- aiq/runtime/session.py +116 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +188 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
- aiq/tool/code_execution/register.py +70 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/mcp_client.py +220 -0
- aiq/tool/mcp/mcp_tool.py +75 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +67 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +36 -0
- aiq/tool/retriever.py +89 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/yaml_tools.py +50 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/type_converter.py +232 -0
- aiq/utils/type_utils.py +397 -0
- aiq/utils/url_utils.py +27 -0
- aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
- aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
- aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
- aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
- aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
aiq/runtime/session.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import contextvars
|
|
18
|
+
import typing
|
|
19
|
+
from collections.abc import Awaitable
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
from contextlib import asynccontextmanager
|
|
22
|
+
from contextlib import nullcontext
|
|
23
|
+
|
|
24
|
+
from aiq.builder.context import AIQContext
|
|
25
|
+
from aiq.builder.context import AIQContextState
|
|
26
|
+
from aiq.builder.workflow import Workflow
|
|
27
|
+
from aiq.data_models.config import AIQConfig
|
|
28
|
+
from aiq.data_models.interactive import HumanResponse
|
|
29
|
+
from aiq.data_models.interactive import InteractionPrompt
|
|
30
|
+
|
|
31
|
+
# Module-level generic type variable.
_T = typing.TypeVar("_T")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class UserManagerBase:
    """Empty base class for user managers; it declares no attributes or methods of its own."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AIQSessionManager:

    def __init__(self, workflow: Workflow, max_concurrency: int = 8):
        """
        Build a session manager that runs one workflow, optionally capping how many runs execute at once.

        Parameters
        ----------
        workflow : Workflow
            The workflow to run. Must not be None.
        max_concurrency : int, optional
            The maximum number of simultaneous workflow invocations, by default 8. A value that is not greater
            than zero disables the limit.

        Raises
        ------
        ValueError
            If ``workflow`` is None.
        """
        if workflow is None:
            raise ValueError("Workflow cannot be None")

        self._workflow: Workflow = workflow
        self._max_concurrency = max_concurrency

        state = AIQContextState.get()
        self._context_state = state
        self._context = AIQContext(state)

        # Snapshot the context variables as they are right now so that `run` can re-apply them later. The
        # original author's note gives the reason: requests served by Uvicorn need the context vars restored.
        self._saved_context = contextvars.copy_context()

        # With no positive limit there is nothing to throttle, but `run` still needs something it can enter
        # with `async with`, hence the no-op context manager.
        self._semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency > 0 else nullcontext()

    @property
    def config(self) -> AIQConfig:
        """The configuration object exposed by the managed workflow."""
        return self._workflow.config

    @property
    def workflow(self) -> Workflow:
        """The workflow passed to the constructor."""
        return self._workflow

    @property
    def context(self) -> AIQContext:
        """The ``AIQContext`` created from the context state captured at construction time."""
        return self._context

    @asynccontextmanager
    async def session(self,
                      user_manager=None,
                      user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None):
        """
        Temporarily install per-session values on the context state and yield this manager.

        Each argument that is not None is stored in the matching context variable for the duration of the
        ``async with`` body and put back to its previous value on exit, even if the body raises.
        """
        state = self._context_state

        # (context variable, reset token) pairs, recorded in the order the variables were set.
        applied = []
        if user_input_callback is not None:
            applied.append((state.user_input_callback, state.user_input_callback.set(user_input_callback)))
        if user_manager is not None:
            applied.append((state.user_manager, state.user_manager.set(user_manager)))

        try:
            yield self
        finally:
            # Undo in reverse order of application.
            for variable, token in reversed(applied):
                variable.reset(token)

    @asynccontextmanager
    async def run(self, message):
        """
        Start a workflow run for ``message`` and yield the runner produced by the workflow.

        The run waits on the concurrency semaphore (if any) and re-applies the context variables that were
        captured when the manager was constructed before delegating to ``Workflow.run``.
        """
        async with self._semaphore:
            # Apply the saved context
            for variable, value in self._saved_context.items():
                variable.set(value)

            async with self._workflow.run(message) as active_runner:
                yield active_runner
|
aiq/settings/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import typing
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
from contextlib import contextmanager
|
|
22
|
+
from copy import deepcopy
|
|
23
|
+
|
|
24
|
+
from platformdirs import user_config_dir
|
|
25
|
+
from pydantic import ConfigDict
|
|
26
|
+
from pydantic import Discriminator
|
|
27
|
+
from pydantic import Tag
|
|
28
|
+
from pydantic import ValidationError
|
|
29
|
+
from pydantic import ValidationInfo
|
|
30
|
+
from pydantic import ValidatorFunctionWrapHandler
|
|
31
|
+
from pydantic import field_validator
|
|
32
|
+
|
|
33
|
+
from aiq.cli.type_registry import GlobalTypeRegistry
|
|
34
|
+
from aiq.cli.type_registry import RegisteredInfo
|
|
35
|
+
from aiq.data_models.common import HashableBaseModel
|
|
36
|
+
from aiq.data_models.common import TypedBaseModel
|
|
37
|
+
from aiq.data_models.common import TypedBaseModelT
|
|
38
|
+
from aiq.data_models.registry_handler import RegistryHandlerBaseConfig
|
|
39
|
+
|
|
40
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Settings(HashableBaseModel):
    """
    User settings model that is loaded from, and saved to, ``config.json`` in the configuration directory.

    The only persisted field is ``channels``. The configuration directory and the list of "settings changed"
    hooks are class-level (``ClassVar``) values and are therefore shared by every ``Settings`` instance.
    """

    model_config = ConfigDict(extra="forbid")

    # Registry handler configuration, keyed by channel name.
    channels: dict[str, RegistryHandlerBaseConfig] = {}

    # Directory that holds `config.json`; assigned through `set_configuration_directory`.
    _configuration_directory: typing.ClassVar[str]
    # Callbacks run after the settings are saved (see `_settings_changed`).
    _settings_changed_hooks: typing.ClassVar[list[Callable[[], None]]] = []
    # When False, `_settings_changed` does not call the hooks.
    _settings_changed_hooks_active: bool = True

    @field_validator("channels", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """
        Validate ``channels`` and turn an unknown or ambiguous handler type into a descriptive ``ValueError``.

        The wrapped ``handler`` performs the normal validation. If it fails with a ``union_tag_invalid``
        error, the requested type name is compared against the registered registry handlers so that the
        message can list what is available. Any other validation failure is re-raised unchanged.

        Raises
        ------
        ValueError
            If the requested handler type is not registered, or matches several handlers by local name.
        AssertionError
            If an internal invariant is violated (unexpected field name, or exactly one local-name match).
        ValidationError
            For every validation failure that is not a ``union_tag_invalid`` error.
        """
        try:
            return handler(value)
        except ValidationError as err:

            for e in err.errors():
                if e['type'] != 'union_tag_invalid' or len(e['loc']) == 0:
                    continue

                requested_type = e['loc'][0]

                # Raised explicitly (rather than with `assert`) so the check survives `python -O`.
                if info.field_name != "channels":
                    raise AssertionError(f"Unknown field name {info.field_name} in validator")

                registered_keys = GlobalTypeRegistry.get().get_registered_registry_handlers()

                # Check and see if there are multiple full types which match this short type
                matching_keys = [k for k in registered_keys if k.local_name == requested_type]

                if len(matching_keys) == 1:
                    raise AssertionError("Exact match should have been found. Contact developers")

                matching_key_names = [x.full_type for x in matching_keys]
                registered_key_names = [x.full_type for x in registered_keys]

                # Built separately so that `.format` only ever sees this literal and never the user-supplied
                # type name interpolated into the messages below.
                available = "\nAvailable {} names:\n - {}".format(info.field_name,
                                                                    '\n - '.join(registered_key_names))

                if not matching_keys:
                    # This is a case where the requested type is not found. Show a helpful message about what
                    # is available
                    raise ValueError(
                        f"Requested {info.field_name} type `{requested_type}` not found. "
                        "Have you ensured the necessary package has been installed with `uv pip install`?" +
                        available) from err

                # This is a case where the requested type is ambiguous.
                raise ValueError(f"Requested {info.field_name} type `{requested_type}` is ambiguous. "
                                 f"Matched multiple {info.field_name} by their local name: {matching_key_names}. "
                                 f"Please use the fully qualified {info.field_name} name." + available) from err

            raise

    @classmethod
    def rebuild_annotations(cls):
        """
        Recompute the ``channels`` annotation from the registered registry handlers and rebuild the model.

        The annotation becomes a ``dict`` whose values are a tagged union over every registered handler
        config type, discriminated by ``TypedBaseModel.discriminator``. The model is only rebuilt when the
        computed annotation differs from the one currently on the field.
        """

        def compute_annotation(cls: type[TypedBaseModelT], registrations: list[RegisteredInfo[TypedBaseModelT]]):
            # Work on a copy: the placeholders appended below must not leak into the caller's list (which may
            # be owned by the type registry -- not verifiable from this module).
            registrations = list(registrations)

            # Pad to at least two entries with `_ignore/<n>` placeholders that point at the base config type.
            # NOTE(review): presumably so that `typing.Union` below always receives more than one member.
            while len(registrations) < 2:
                registrations.append(RegisteredInfo[TypedBaseModelT](full_type=f"_ignore/{len(registrations)}",
                                                                     config_type=cls))

            short_names: dict[str, int] = {}
            type_list: list[tuple[str, type[TypedBaseModelT]]] = []

            # Count how often each local (short) name occurs and register every full name as a tag
            for key in registrations:
                short_names[key.local_name] = short_names.get(key.local_name, 0) + 1

                type_list.append((key.full_type, key.config_type))

            # A short name that is unique additionally becomes a tag of its own
            for key in registrations:
                if short_names[key.local_name] == 1:
                    type_list.append((key.local_name, key.config_type))

            # pylint: disable=consider-alternative-union-syntax
            return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]

        RegistryHandlerAnnotation = dict[
            str,
            typing.Annotated[compute_annotation(RegistryHandlerBaseConfig,
                                                GlobalTypeRegistry.get().get_registered_registry_handlers()),
                             Discriminator(TypedBaseModel.discriminator)]]

        should_rebuild = False

        channels_field = cls.model_fields.get("channels")
        if channels_field is not None and channels_field.annotation != RegistryHandlerAnnotation:
            channels_field.annotation = RegistryHandlerAnnotation
            should_rebuild = True

        if should_rebuild:
            cls.model_rebuild(force=True)

    @property
    def channel_names(self) -> list:
        """Names of all configured channels."""
        return list(self.channels.keys())

    @property
    def configuration_directory(self) -> str:
        """Directory that holds the settings file (class-level value shared by all instances)."""
        return self._configuration_directory

    @property
    def configuration_file(self) -> str:
        """Path of ``config.json`` inside the configuration directory."""
        return os.path.join(self.configuration_directory, "config.json")

    @staticmethod
    def from_file():
        """
        Load the settings from ``config.json`` in the configuration directory.

        The directory is taken from the ``AIQ_CONFIG_DIR`` environment variable, falling back to the
        platform user config directory for the app name ``aiq``; it is created when missing. A missing
        ``config.json`` yields default settings.

        Returns
        -------
        Settings
            The loaded settings, with the configuration directory recorded on the class.
        """
        configuration_directory = os.getenv("AIQ_CONFIG_DIR", user_config_dir(appname="aiq"))
        os.makedirs(configuration_directory, exist_ok=True)

        configuration_file = os.path.join(configuration_directory, "config.json")

        if os.path.exists(configuration_file):
            with open(configuration_file, mode="r", encoding="utf-8") as f:
                loaded_config = json.load(f)
        else:
            loaded_config = {}

        settings = Settings(**loaded_config)
        settings.set_configuration_directory(configuration_directory)
        return settings

    def set_configuration_directory(self, directory: str, remove: bool = False) -> None:
        """
        Point the settings at ``directory``.

        When ``remove`` is true the current configuration directory is removed first, if it exists.
        ``os.rmdir`` only removes empty directories and raises ``OSError`` otherwise.
        """
        if remove and os.path.exists(self.configuration_directory):
            os.rmdir(self.configuration_directory)

        # `_configuration_directory` is a ClassVar, so it has to be assigned on the class.
        self.__class__._configuration_directory = directory

    def reset_configuration_directory(self, remove: bool = False) -> None:
        """
        Reset the configuration directory to ``AIQ_CONFIG_DIR`` or, when unset, the platform default.

        When ``remove`` is true the current configuration directory is removed first, if it exists.
        ``os.rmdir`` only removes empty directories and raises ``OSError`` otherwise.
        """
        if remove and os.path.exists(self.configuration_directory):
            os.rmdir(self.configuration_directory)

        # Assigned on the class, exactly like `set_configuration_directory`: the attribute is a ClassVar and
        # cannot be set through an instance.
        self.__class__._configuration_directory = os.getenv("AIQ_CONFIG_DIR", user_config_dir(appname="aiq"))

    def _save_settings(self) -> None:
        """Write the settings to ``configuration_file`` as JSON and then run the settings-changed hooks."""
        # makedirs (not mkdir) so that missing parent directories are created as well.
        os.makedirs(self.configuration_directory, exist_ok=True)

        with open(self.configuration_file, mode="w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=4, by_alias=True, serialize_as_any=True))

        self._settings_changed()

    def update_settings(self, config_obj: "dict | Settings"):
        """Validate ``config_obj``, apply it to this instance and save; see ``_update_settings``."""
        self._update_settings(config_obj)

    def _update_settings(self, config_obj: "dict | Settings"):
        """
        Apply ``config_obj`` (a plain dict or another ``Settings``) to this instance and save.

        The settings are saved, and the hooks run, whether or not validation succeeded; on a validation
        failure the instance keeps its previous values and the failure is only logged by ``_revalidate``.
        """
        if isinstance(config_obj, Settings):
            config_obj = config_obj.model_dump(serialize_as_any=True, by_alias=True)

        self._revalidate(config_dict=config_obj)

        self._save_settings()

    def _revalidate(self, config_dict) -> bool:
        """
        Validate ``config_dict`` by constructing a new instance and copy its fields onto ``self``.

        Returns
        -------
        bool
            True when validation succeeded and the fields were copied, False when anything raised (the
            exception is logged, not propagated).
        """
        try:
            validated_data = self.__class__(**config_dict)

            for field in type(validated_data).model_fields:
                if field == "channels":
                    self.channels = validated_data.channels
                else:
                    # A field was added to the model without teaching this method how to copy it.
                    raise ValueError(f"Encountered invalid model field: {field}")

            return True

        except Exception as e:
            # Deliberately broad: any failure is reported through the log and the return value.
            logger.exception("Unable to validate user settings configuration: %s", e)
            return False

    def print_channel_settings(self, channel_type: str | None = None) -> None:
        """
        Log the configured channels as YAML, optionally restricted to one handler type.

        Parameters
        ----------
        channel_type : str | None
            When given, only channels whose dumped ``type`` value equals it are listed.
        """
        import yaml

        remote_channels = self.model_dump(serialize_as_any=True, by_alias=True)
        channels = remote_channels.get("channels") if remote_channels else None

        if not channels:
            logger.warning("No configured channels to list.")
            return

        if channel_type is not None:
            # Filter the individual channel entries (not the top-level dump, whose only key is "channels").
            channels = {name: config for name, config in channels.items() if config.get("type") == channel_type}
            remote_channels["channels"] = channels

        if channels:
            logger.info(yaml.dump(remote_channels, allow_unicode=True, default_flow_style=False))

    def override_settings(self, config_file: str) -> "Settings":
        """
        Override the current settings with the top-level keys found in the YAML file ``config_file``.

        Keys present in the file replace the corresponding current values; keys absent from the file keep
        their current values. The result is validated and saved through ``_update_settings``.

        Returns
        -------
        Settings
            This instance, to allow chaining.
        """
        from aiq.utils.io.yaml_tools import yaml_load

        # `or {}` guards against the loader handing back nothing for an empty file
        # (NOTE(review): yaml_load's behavior for empty files is not visible here).
        override_settings_dict = yaml_load(config_file) or {}

        settings_dict = self.model_dump()
        # Later entries win, so the override file must come last.
        updated_settings = {**settings_dict, **override_settings_dict}
        self._update_settings(config_obj=updated_settings)

        return self

    def _settings_changed(self):
        """Call every registered settings-changed hook, unless hooks are currently paused."""
        if not self._settings_changed_hooks_active:
            return

        for hook in self._settings_changed_hooks:
            hook()

    @contextmanager
    def pause_settings_changed_hooks(self):
        """Suspend the settings-changed hooks for the ``with`` body, then run them once on exit."""
        self._settings_changed_hooks_active = False

        try:
            yield
        finally:
            self._settings_changed_hooks_active = True

            # Ensure that the settings changed hooks are called
            self._settings_changed()

    def add_settings_changed_hook(self, cb: Callable[[], None]) -> None:
        """Register ``cb`` to be called after the settings change (the hook list is shared class-wide)."""
        self._settings_changed_hooks.append(cb)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# Register a hook on the global type registry that recomputes the `Settings.channels` annotation;
# presumably the registry invokes it whenever its registrations change.
GlobalTypeRegistry.get().add_registration_changed_hook(lambda: Settings.rebuild_annotations())
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class GlobalSettings:
    """Holder for the single, lazily created ``Settings`` object used across the process."""

    # The cached settings; stays None until `get` is called for the first time.
    _global_settings: Settings | None = None

    @staticmethod
    def get() -> Settings:
        """
        Return the global settings, loading them on first use.

        The first call registers the registry-handler plugins and then reads the settings from file;
        later calls return the cached object.
        """
        cached = GlobalSettings._global_settings
        if cached is not None:
            return cached

        # Imported here rather than at module level, as in the original code.
        from aiq.runtime.loader import PluginTypes
        from aiq.runtime.loader import discover_and_register_plugins

        discover_and_register_plugins(PluginTypes.REGISTRY_HANDLER)

        GlobalSettings._global_settings = Settings.from_file()

        return GlobalSettings._global_settings

    @staticmethod
    @contextmanager
    def push():
        """
        Swap in a deep copy of the global settings for the duration of the ``with`` body.

        Yields the copy. On exit the previous settings object is restored and its settings-changed
        hooks are triggered, even if the body raised.
        """
        previous = GlobalSettings.get()
        scratch = deepcopy(previous)

        GlobalSettings._global_settings = scratch
        try:
            yield scratch
        finally:
            GlobalSettings._global_settings = previous
            previous._settings_changed()
|
aiq/test/.namespace
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
Note: This is a Python namespace package and this directory should remain empty. Do NOT add an `__init__.py` file or any other files to this directory. This file is also needed to ensure the directory exists in git.
|
aiq/tool/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import abc
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
from urllib.parse import urljoin
|
|
19
|
+
|
|
20
|
+
import requests
|
|
21
|
+
from pydantic import HttpUrl
|
|
22
|
+
|
|
23
|
+
# Name the logger after the module rather than its file path, so it takes part in the
# dotted logger hierarchy and inherits the configuration of its parent loggers.
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Sandbox(abc.ABC):
    """Base class for clients of an HTTP code-execution sandbox.

    Subclasses provide the server-specific pieces: how the execute endpoint is derived
    from the base URI, how the request payload is built and how the HTTP response is
    decoded.

    Args:
        uri: Base URL of the sandbox server. It is handed to ``_get_execute_url`` to
            obtain the endpoint that execution requests are POSTed to.
    """

    def __init__(
        self,
        *,
        uri: HttpUrl,
    ):
        self.url = self._get_execute_url(uri)
        session = requests.Session()
        # A single adapter (large connection pool, up to 3 retries) serves both schemes.
        adapter = requests.adapters.HTTPAdapter(pool_maxsize=1500, pool_connections=1500, max_retries=3)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        self.http_session = session

    def _send_request(self, request, timeout):
        """POST ``request`` as JSON to the sandbox and return the parsed response.

        Args:
            request: JSON-serializable payload produced by ``_prepare_request``.
            timeout: Timeout, in seconds, of the HTTP request.

        Returns:
            Whatever ``_parse_request_output`` returns for the HTTP response.

        Raises:
            requests.exceptions.Timeout: If the request times out or the server answers
                with HTTP 502.
        """
        output = self.http_session.post(
            url=self.url,
            data=json.dumps(request),
            timeout=timeout,
            headers={"Content-Type": "application/json"},
        )
        # A 502 is not retried here: it is reported as a timeout so that callers
        # handle a bad gateway the same way as an unresponsive sandbox.
        if output.status_code == 502:
            raise requests.exceptions.Timeout

        return self._parse_request_output(output)

    @abc.abstractmethod
    def _parse_request_output(self, output):
        """Convert the HTTP response of the sandbox into the execution result."""

    @abc.abstractmethod
    def _get_execute_url(self, uri):
        """Return the URL of the execute endpoint for the given base ``uri``."""

    @abc.abstractmethod
    def _prepare_request(self, generated_code, timeout):
        """Build the JSON-serializable request payload for ``generated_code``."""

    async def execute_code(
        self,
        generated_code: str,
        timeout: float = 10.0,
        language: str = "python",
        max_output_characters: int = 1000,
    ) -> dict:
        """Run ``generated_code`` in the sandbox and return the execution result.

        The code is embedded in a wrapper script that captures stdout and stderr,
        records whether the code raised, and prints the outcome as one JSON object.

        Args:
            generated_code: Python source to execute. Surrounding whitespace and
                leading/trailing backticks are stripped first.
            timeout: Seconds; passed to ``_prepare_request`` and used as the HTTP
                request timeout.
            language: Accepted for interface compatibility but not used by this
                method; it is not forwarded to ``_prepare_request``.
            max_output_characters: Limit applied separately to the captured stdout and
                stderr; longer output is cut and suffixed with ``<output cut>``.

        Returns:
            The result produced by ``_parse_request_output``, or a dictionary with
            ``process_status`` set to ``"timeout"`` if the request timed out.
        """
        import asyncio

        generated_code = generated_code.strip().strip("`")

        # The wrapper (including its `exec` call) is only sent to the sandbox server;
        # it is never executed in this process.
        code_to_execute = """
import traceback
import json
import os
import warnings
import contextlib
import io
warnings.filterwarnings('ignore')
os.environ['OPENBLAS_NUM_THREADS'] = '16'
"""

        code_to_execute += f"""
\ngenerated_code = {repr(generated_code)}\n
stdout = io.StringIO()
stderr = io.StringIO()

with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
    try:
        exec(generated_code)
        status = "completed"
    except Exception:
        status = "error"
        stderr.write(traceback.format_exc())
stdout = stdout.getvalue()
stderr = stderr.getvalue()
if len(stdout) > {max_output_characters}:
    stdout = stdout[:{max_output_characters}] + "<output cut>"
if len(stderr) > {max_output_characters}:
    stderr = stderr[:{max_output_characters}] + "<output cut>"
if stdout:
    stdout += "\\n"
if stderr:
    stderr += "\\n"
output = {{"process_status": status, "stdout": stdout, "stderr": stderr}}
print(json.dumps(output))
"""
        request = self._prepare_request(code_to_execute, timeout)
        try:
            # `requests` is blocking; run it in a worker thread so that a slow sandbox
            # does not stall the event loop for up to `timeout` seconds.
            output = await asyncio.to_thread(self._send_request, request, timeout)
        except requests.exceptions.Timeout:
            output = {"process_status": "timeout", "stdout": "", "stderr": "Timed out\n"}
        return output
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class LocalSandbox(Sandbox):
    """Locally hosted sandbox."""

    def _get_execute_url(self, uri):
        """Return ``execute`` resolved against ``uri`` with ``urljoin`` semantics.

        With ``urljoin`` a base without a trailing slash has its last path segment
        replaced rather than extended.
        """
        return urljoin(str(uri), "execute")

    def _parse_request_output(self, output):
        """Decode the JSON body of the response.

        Returns:
            The decoded body, or an error result with ``process_status`` set to
            ``"error"`` when the body is not valid JSON.
        """
        try:
            return output.json()
        except ValueError as e:
            # ValueError is the common base of every JSON decode error that
            # `Response.json()` can raise (stdlib json, simplejson and requests' own
            # JSONDecodeError), whereas json.JSONDecodeError misses the simplejson case.
            logger.exception("Error parsing output: %s. %s", output.text, e)
            return {'process_status': 'error', 'stdout': '', 'stderr': 'Unknown error'}

    def _prepare_request(self, generated_code, timeout, language='python', **kwargs):
        """Build the request payload; extra keyword arguments are accepted and ignored."""
        return {
            "generated_code": generated_code,
            "timeout": timeout,
            "language": language,
        }
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class PistonSandbox(Sandbox):
    """Piston sandbox (https://github.com/engineer-man/piston)"""

    def _get_execute_url(self, uri):
        """Return ``execute`` resolved against ``uri`` with ``urljoin`` semantics."""
        return urljoin(str(uri), "execute")

    def _parse_request_output(self, output):
        """Extract the execution result from a Piston response.

        The wrapper script prints its result as JSON, so the ``output`` field of the
        ``run`` section is decoded and returned.

        Returns:
            The decoded result, or an error result with ``process_status`` set to
            ``"error"`` when the process was killed or the response cannot be decoded.
        """
        try:
            run_info = output.json()['run']
            if run_info['signal'] == "SIGKILL":
                message = 'Unknown error: SIGKILL'
                return {
                    # Same schema as every other result of this module.
                    'process_status': 'error',
                    'stdout': '',
                    'stderr': message,
                    # Keys of the previous return value, kept for existing callers.
                    'result': None,
                    'error_message': message,
                }
            return json.loads(run_info['output'])
        except (ValueError, KeyError) as e:
            # ValueError: the response or the program output is not valid JSON.
            # KeyError: the response lacks the expected `run` section or one of its fields.
            logger.exception("Error parsing output: %s. %s", output.text, e)
            return {'process_status': 'error', 'stdout': '', 'stderr': 'Unknown error'}

    def _prepare_request(self, generated_code: str, timeout, **kwargs):
        """Build the Piston execute payload; extra keyword arguments are ignored."""
        return {
            "language": "py",
            "version": "3.10.0",
            "files": [{
                "content": generated_code,
            }],
            "stdin": "",
            "args": [],
            "run_timeout": timeout * 1000.0,  # milliseconds
            "compile_memory_limit": -1,
            "run_memory_limit": -1,
        }
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# Registry of the available sandbox clients, keyed by lowercase type name.
sandboxes = dict(local=LocalSandbox, piston=PistonSandbox)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def get_sandbox(sandbox_type: str = "local", **kwargs):
    """Instantiate the sandbox client registered under ``sandbox_type``.

    Args:
        sandbox_type: Key into ``sandboxes``; it is lowercased before the lookup.
        **kwargs: Forwarded unchanged to the constructor of the selected class.

    Returns:
        A new instance of the selected sandbox class.

    Raises:
        KeyError: If the lowercased ``sandbox_type`` is not a key of ``sandboxes``.
    """
    registry_key = sandbox_type.lower()
    return sandboxes[registry_key](**kwargs)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Use the base image with Python 3.10 and Flask
FROM tiangolo/uwsgi-nginx-flask:python3.10

# Install dependencies required for Lean 4 and other tools.
# The apt package lists are removed in the same layer so they do not end up in the image.
RUN apt-get update && \
    apt-get install -y curl git && \
    rm -rf /var/lib/apt/lists/* && \
    curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y && \
    /root/.elan/bin/elan toolchain install leanprover/lean4:v4.12.0 && \
    /root/.elan/bin/elan default leanprover/lean4:v4.12.0 && \
    /root/.elan/bin/elan self update

# Set environment variables to include Lean and elan/lake in the PATH
ENV PATH="/root/.elan/bin:$PATH"

# Create Lean project directory and initialize a new Lean project with Mathlib4
RUN mkdir -p /lean4 && cd /lean4 && \
    /root/.elan/bin/lake new my_project && \
    cd my_project && \
    echo 'leanprover/lean4:v4.12.0' > lean-toolchain && \
    echo 'require mathlib from git "https://github.com/leanprover-community/mathlib4" @ "v4.12.0"' >> lakefile.lean

# Download and cache Mathlib4 to avoid recompiling, then build the project
RUN cd /lean4/my_project && \
    /root/.elan/bin/lake exe cache get && \
    /root/.elan/bin/lake build

# Set environment variables to include Lean project path
ENV LEAN_PATH="/lean4/my_project"
ENV PATH="/lean4/my_project:$PATH"

# Set up application code and install Python dependencies
COPY sandbox.requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY local_sandbox_server.py /app/main.py

# Set the working directory to /app
WORKDIR /app

# Set Flask app environment variables and ports.
# Both uWSGI values are taken from build arguments of the same name.
ARG UWSGI_CHEAPER
ENV UWSGI_CHEAPER=$UWSGI_CHEAPER

ARG UWSGI_PROCESSES
ENV UWSGI_PROCESSES=$UWSGI_PROCESSES

ENV LISTEN_PORT=6000
|