aiqtoolkit 1.1.0a20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +76 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +322 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +46 -0
- aiq/agent/react_agent/register.py +148 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +410 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +123 -0
- aiq/agent/tool_calling_agent/register.py +105 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +223 -0
- aiq/builder/component_utils.py +303 -0
- aiq/builder/context.py +198 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +116 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +297 -0
- aiq/builder/function_base.py +372 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +125 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +71 -0
- aiq/builder/workflow.py +134 -0
- aiq/builder/workflow_builder.py +733 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +233 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +34 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +37 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +36 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/start.py +250 -0
- aiq/cli/commands/uninstall.py +83 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +307 -0
- aiq/cli/entrypoint.py +133 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +408 -0
- aiq/cli/type_registry.py +869 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +550 -0
- aiq/data_models/common.py +143 -0
- aiq/data_models/component.py +46 -0
- aiq/data_models/component_ref.py +135 -0
- aiq/data_models/config.py +349 -0
- aiq/data_models/dataset_handler.py +122 -0
- aiq/data_models/discovery_metadata.py +269 -0
- aiq/data_models/embedder.py +26 -0
- aiq/data_models/evaluate.py +101 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +64 -0
- aiq/data_models/interactive.py +237 -0
- aiq/data_models/intermediate_step.py +269 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +26 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +26 -0
- aiq/data_models/profiler.py +53 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +58 -0
- aiq/embedder/openai_embedder.py +42 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +42 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +164 -0
- aiq/eval/evaluate.py +322 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/evaluator_model.py +44 -0
- aiq/eval/intermediate_step_adapter.py +93 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +138 -0
- aiq/eval/rag_evaluator/register.py +138 -0
- aiq/eval/register.py +22 -0
- aiq/eval/remote_workflow.py +128 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +118 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +131 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +107 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +161 -0
- aiq/front_ends/fastapi/main.py +70 -0
- aiq/front_ends/fastapi/message_handler.py +279 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +181 -0
- aiq/front_ends/fastapi/step_adaptor.py +315 -0
- aiq/front_ends/fastapi/websocket.py +148 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/nim_llm.py +45 -0
- aiq/llm/openai_llm.py +45 -0
- aiq/llm/register.py +22 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +102 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +59 -0
- aiq/observability/__init__.py +0 -0
- aiq/observability/async_otel_listener.py +270 -0
- aiq/observability/register.py +97 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +25 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +433 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +198 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +81 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +188 -0
- aiq/runtime/runner.py +176 -0
- aiq/runtime/session.py +116 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +188 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
- aiq/tool/code_execution/register.py +70 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/mcp_client.py +220 -0
- aiq/tool/mcp/mcp_tool.py +75 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +67 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +36 -0
- aiq/tool/retriever.py +89 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/yaml_tools.py +50 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/type_converter.py +232 -0
- aiq/utils/type_utils.py +397 -0
- aiq/utils/url_utils.py +27 -0
- aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
- aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
- aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
- aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
- aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from aiq.builder.builder import Builder
|
|
21
|
+
from aiq.builder.function_info import FunctionInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_function
|
|
23
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
from aiq.memory.models import MemoryItem
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AddToolConfig(FunctionBaseConfig, name="add_memory"):
    """Configuration for the ``add_memory`` tool, which stores memories in a hosted memory platform."""

    # Text presented to tool-calling agents describing when this tool should be used.
    description: str = Field(
        default="Tool to add memory about a user's interactions to a system for retrieval later.",
        description="The description of this function's use for tool calling agents.",
    )
    # Name of the memory client entry (from the workflow configuration) that receives new memories.
    memory: MemoryRef = Field(
        default="saas_memory",
        description="Instance name of the memory client instance from the workflow configuration object.",
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=AddToolConfig)
async def add_memory_tool(config: AddToolConfig, builder: Builder):
    """
    Register a tool that writes a single ``MemoryItem`` to the configured memory client.

    Args:
        config: Tool configuration naming the memory client and the agent-facing description.
        builder: Workflow builder used to look up the memory client.

    Yields:
        A ``FunctionInfo`` wrapping the asynchronous add operation.
    """
    from langchain_core.tools import ToolException

    # Resolve the memory client once; every invocation of the tool reuses it.
    editor = builder.get_memory_client(config.memory)

    async def _arun(item: MemoryItem) -> str:
        """
        Store ``item`` through the memory client.

        Raises:
            ToolException: wrapping any error raised while adding the item.
        """
        try:
            await editor.add_items([item])
        except Exception as exc:
            raise ToolException(f"Error adding memory: {exc}") from exc
        return "Memory added successfully. You can continue. Please respond to the user."

    yield FunctionInfo.from_fn(_arun, description=config.description)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from aiq.builder.builder import Builder
|
|
21
|
+
from aiq.builder.function_info import FunctionInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_function
|
|
23
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
from aiq.memory.models import DeleteMemoryInput
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
    """Function to delete memory from a hosted memory platform."""

    # Text presented to tool-calling agents. It must describe *deletion*: the previous default was
    # copied from the get_memory tool ("Tool to retrieve memory ..."), which could lead an agent to
    # call this destructive tool when it only meant to read memories.
    description: str = Field(default=("Tool to delete memories about a user's "
                                      "interactions from the memory system."),
                             description="The description of this function's use for tool calling agents.")
    # Name of the memory client entry (from the workflow configuration) to delete memories from.
    memory: MemoryRef = Field(default="saas_memory",
                              description=("Instance name of the memory client instance from the workflow "
                                           "configuration object."))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=DeleteToolConfig)
async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
    """
    Register a tool that removes a user's memories from the configured memory client.

    Args:
        config: Tool configuration naming the memory client and the agent-facing description.
        builder: Workflow builder used to look up the memory client.

    Yields:
        A ``FunctionInfo`` wrapping the asynchronous delete operation, whose input is
        described by ``DeleteMemoryInput``.
    """
    from langchain_core.tools import ToolException

    # Resolve the memory client once; every invocation of the tool reuses it.
    editor = builder.get_memory_client(config.memory)

    async def _arun(user_id: str) -> str:
        """
        Ask the memory client to remove items for ``user_id``.

        Raises:
            ToolException: wrapping any error raised while removing items.
        """
        try:
            await editor.remove_items(user_id=user_id)
        except Exception as exc:
            raise ToolException(f"Error deleting memory: {exc}") from exc
        return "Memories deleted!"

    yield FunctionInfo.from_fn(_arun, description=config.description, input_schema=DeleteMemoryInput)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from aiq.builder.builder import Builder
|
|
21
|
+
from aiq.builder.function_info import FunctionInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_function
|
|
23
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
from aiq.memory.models import SearchMemoryInput
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GetToolConfig(FunctionBaseConfig, name="get_memory"):
    """Configuration for the ``get_memory`` tool, which searches memories in a hosted memory platform."""

    # Text presented to tool-calling agents describing when this tool should be used.
    description: str = Field(
        default="Tool to retrieve memory about a user's interactions to help answer questions in a personalized way.",
        description="The description of this function's use for tool calling agents.",
    )
    # Name of the memory client entry (from the workflow configuration) to search.
    memory: MemoryRef = Field(
        default="saas_memory",
        description="Instance name of the memory client instance from the workflow configuration object.",
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=GetToolConfig)
async def get_memory_tool(config: GetToolConfig, builder: Builder):
    """
    Function to get memory from a hosted memory platform.

    Registers a tool that searches the configured memory client and returns the matches as JSON text.

    Args:
        config: Tool configuration naming the memory client and the agent-facing description.
        builder: Workflow builder used to look up the memory client.

    Yields:
        A ``FunctionInfo`` wrapping the asynchronous search operation.
    """
    import json

    from langchain_core.tools import ToolException

    # First, retrieve the memory client
    memory_editor = builder.get_memory_client(config.memory)

    async def _arun(search_input: SearchMemoryInput) -> str:
        """
        Asynchronous execution of collection of memories.

        Searches the memory client using ``query``, ``top_k`` and ``user_id`` from ``search_input`` and
        returns the matching memories serialized as a JSON list, prefixed with a short header.

        Raises:
            ToolException: wrapping any error raised by the search or the serialization.
        """
        try:
            memories = await memory_editor.search(
                query=search_input.query,
                top_k=search_input.top_k,
                user_id=search_input.user_id,
            )

            # model_dump(mode='json') turns each memory model into JSON-compatible data for json.dumps.
            memory_str = f"Memories as a JSON: \n{json.dumps([mem.model_dump(mode='json') for mem in memories])}"
            return memory_str

        except Exception as e:
            raise ToolException(f"Error retrieving memory: {e}") from e

    yield FunctionInfo.from_fn(_arun, description=config.description)
|
aiq/tool/nvidia_rag.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from aiq.builder.builder import Builder
|
|
22
|
+
from aiq.builder.function_info import FunctionInfo
|
|
23
|
+
from aiq.cli.register_workflow import register_function
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class NVIDIARAGToolConfig(FunctionBaseConfig, name="nvidia_rag"):
    """
    Configuration for the ``nvidia_rag`` tool, which searches the NVIDIA Developer database
    (through a RAG service) for documents across a variety of NVIDIA asset types.
    """

    # Root URL of the RAG service; required, no default.
    base_url: str = Field(description="The base url to the RAG service.")
    # Request timeout handed to the HTTP client.
    timeout: int = Field(default=60, description="The timeout configuration to use when sending requests.")
    # Text placed between consecutive formatted documents in the tool output.
    document_separator: str = Field(default="\n\n", description="The delimiter to use between retrieved documents.")
    # Template applied to each retrieved document; placeholders are filled from document metadata/content.
    document_prompt: str = Field(
        default=("-------\n\n"
                 "Title: {document_title}\n"
                 "Text: {page_content}\n"
                 "Source URL: {document_url}"),
        description="The prompt to use to retrieve documents from the RAG service",
    )
    # Number of results requested from the service.
    top_k: int = Field(default=4, description="The number of results to return from the RAG service.")
    # Collection queried on the service.
    collection_name: str = Field(
        default="nvidia_api_catalog",
        description="The name of the collection to use when retrieving documents.",
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@register_function(config_type=NVIDIARAGToolConfig)
async def nvidia_rag_tool(config: NVIDIARAGToolConfig, builder: Builder):
    """
    Register a tool that queries a RAG service's ``/search`` endpoint and formats the returned chunks.

    Args:
        config: Connection, retrieval and formatting settings for the RAG service.
        builder: Workflow builder (not referenced by this tool).

    Yields:
        A ``FunctionInfo`` wrapping the asynchronous search function. The HTTP client is shared by
        all invocations and is closed when the ``async with`` block exits after the yield.
    """
    import httpx
    from langchain.prompts import PromptTemplate
    from langchain_core.documents import Document
    from langchain_core.prompts import aformat_document

    document_prompt = PromptTemplate.from_template(config.document_prompt)

    # Build the endpoint once. Strip any trailing slash from the configured base URL so that both
    # "http://host/" and "http://host" resolve to "http://host/search" instead of ".../host//search".
    url = f"{config.base_url.rstrip('/')}/search"

    async with httpx.AsyncClient(headers={
            "accept": "application/json", "Content-Type": "application/json"
    },
                                 timeout=config.timeout) as client:

        async def runnable(query: str) -> str:
            """
            Search the RAG service for ``query`` and return the formatted matching documents.

            On any failure (network, HTTP status, unexpected response shape) an error message string is
            returned instead of raising, so the calling agent can continue.
            """
            try:
                payload = {"query": query, "top_k": config.top_k, "collection_name": config.collection_name}

                logger.debug("Sending request to the RAG endpoint %s.", url)
                response = await client.post(url, content=json.dumps(payload))

                response.raise_for_status()

                output = response.json()

                # Assumes the response has a "chunks" list whose items carry "content", "filename"
                # and "score" keys; a missing key is reported through the except branch below.
                docs = [
                    Document(
                        page_content=chunk["content"],
                        metadata={
                            "document_title": chunk["filename"],
                            # NOTE(review): constant placeholder, not taken from the service response.
                            "document_url": "nemo_framework",
                            "document_full_text": chunk["content"],
                            "score_rerank": chunk["score"],
                        },
                        type="Document",
                    ) for chunk in output["chunks"]
                ]

                formatted = [await aformat_document(doc, document_prompt) for doc in docs]
                return config.document_separator.join(formatted)
            except Exception as e:
                # Best-effort tool: log the traceback (logger.exception already includes it) and hand
                # the agent a readable error string.
                logger.exception("Error while running the tool")
                return f"Error while running the tool: {e}"

        yield FunctionInfo.from_fn(
            runnable,
            description=("Search the NVIDIA Developer database for documents across a variety of "
                         "NVIDIA asset types"))
|
aiq/tool/register.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# pylint: disable=unused-import
|
|
17
|
+
# flake8: noqa
|
|
18
|
+
|
|
19
|
+
# Import any tools which need to be automatically registered here
|
|
20
|
+
from . import datetime_tools
|
|
21
|
+
from . import document_search
|
|
22
|
+
from . import github_tools
|
|
23
|
+
from . import nvidia_rag
|
|
24
|
+
from . import retriever
|
|
25
|
+
from .code_execution import register
|
|
26
|
+
from .github_tools import create_github_commit
|
|
27
|
+
from .github_tools import create_github_issue
|
|
28
|
+
from .github_tools import create_github_pr
|
|
29
|
+
from .github_tools import get_github_file
|
|
30
|
+
from .github_tools import get_github_issue
|
|
31
|
+
from .github_tools import get_github_pr
|
|
32
|
+
from .github_tools import update_github_issue
|
|
33
|
+
from .mcp import mcp_tool
|
|
34
|
+
from .memory_tools import add_memory_tool
|
|
35
|
+
from .memory_tools import delete_memory_tool
|
|
36
|
+
from .memory_tools import get_memory_tool
|
aiq/tool/retriever.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from aiq.builder.builder import Builder
|
|
22
|
+
from aiq.builder.function_info import FunctionInfo
|
|
23
|
+
from aiq.cli.register_workflow import register_function
|
|
24
|
+
from aiq.data_models.component_ref import RetrieverRef
|
|
25
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
26
|
+
from aiq.retriever.interface import AIQRetriever
|
|
27
|
+
from aiq.retriever.models import RetrieverError
|
|
28
|
+
from aiq.retriever.models import RetrieverOutput
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AIQRetrieverConfig(FunctionBaseConfig, name="aiq_retriever"):
    """
    AIQRetriever tool which provides a common interface for different vectorstores. Its
    configuration uses clients, which are the vectorstore-specific implementation of the retriever interface.

    Attributes
    ----------
    retriever : RetrieverRef
        Name of the retriever instance, as declared in the workflow configuration, that the tool queries.
    raise_errors : bool
        When True (the default), a ``RetrieverError`` raised during a search propagates to the caller. When
        False, the error is logged as a warning and an empty result is returned instead.
    topic : str | None
        Optional subject matter used to build a more specific tool description for the agent.
    description : str | None
        Optional explicit tool description; when set it is used verbatim instead of the generated one.
    """
    retriever: RetrieverRef = Field(description="The retriever instance name from the workflow configuration object.")
    raise_errors: bool = Field(
        default=True,
        description="If true the tool will raise exceptions, otherwise it will log them as warnings and return []",
    )
    topic: str | None = Field(default=None, description="Used to provide a more detailed tool description to the agent")
    description: str | None = Field(default=None, description="If present it will be used as the tool description")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_description_from_config(config: AIQRetrieverConfig) -> str:
    """
    Build the tool description that is shown to the agent.

    A non-empty ``config.description`` is returned verbatim. Otherwise a generic sentence is produced,
    narrowed with "related to <topic>" whenever ``config.topic`` is set.
    """
    if config.description:
        return config.description

    topic_clause = f" related to {config.topic}" if config.topic else ""
    return f"Retrieve document chunks{topic_clause} which can be used to answer the provided question."
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@register_function(config_type=AIQRetrieverConfig)
async def aiq_retriever_tool(config: AIQRetrieverConfig, builder: Builder):
    """
    Configure an AgentIQ Retriever Tool which supports different clients such as Milvus and Nemo Retriever.

    The retriever named by ``config.retriever`` is resolved once from the builder and wrapped in a
    single-argument ``query -> RetrieverOutput`` function that is yielded to the workflow.

    Args:
        config: The tool configuration. ``retriever`` names the retriever to resolve; ``raise_errors`` decides
            whether a ``RetrieverError`` propagates or is turned into an empty result; ``topic`` and
            ``description`` shape the tool description shown to the agent.
        builder: A workflow builder object, used to resolve the configured retriever.

    Yields:
        FunctionInfo: The retrieval function together with its input schema and description.
    """

    class RetrieverInputSchema(BaseModel):
        query: str = Field(description="The query to be searched in the configured data store")

    client: AIQRetriever = await builder.get_retriever(config.retriever)

    async def _retrieve(query: str) -> RetrieverOutput:
        # Only the search call can raise RetrieverError, so keep it alone inside the try.
        try:
            retrieved_context = await client.search(query=query)
        except RetrieverError as e:
            if config.raise_errors:
                # Bare ``raise`` re-raises the active exception with its original traceback.
                raise
            # Best-effort mode: record the failure but hand the agent an empty result to continue with.
            logger.warning("Retriever threw an error: %s. Returning an empty response.", e)
            return RetrieverOutput(results=[])

        logger.info("Retrieved %s records for query %s.", len(retrieved_context), query)
        return retrieved_context

    yield FunctionInfo.from_fn(
        fn=_retrieve,
        input_schema=RetrieverInputSchema,
        description=_get_description_from_config(config),
    )
|
aiq/utils/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import yaml
|
|
17
|
+
from pydantic import ValidationError
|
|
18
|
+
|
|
19
|
+
from ..exception_handlers.schemas import schema_exception_handler
|
|
20
|
+
from ..exception_handlers.schemas import yaml_exception_handler
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@schema_exception_handler
def validate_schema(metadata, Schema):  # pylint: disable=invalid-name
    """
    Instantiate ``Schema`` from the key/value pairs in ``metadata``.

    Parameters
    ----------
    metadata : mapping
        Keyword arguments forwarded to the ``Schema`` constructor.
    Schema : type
        The model class to build; presumably a Pydantic model (``ValidationError`` is what is caught) — confirm.

    Returns
    -------
    The constructed ``Schema`` instance.

    Raises
    ------
    ValidationError
        Re-raised unchanged from construction; the ``schema_exception_handler`` decorator is expected to
        translate it for callers.
    """
    try:
        validated = Schema(**metadata)
    except ValidationError:
        # Propagate untouched so the decorator sees the original exception.
        raise
    return validated
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@yaml_exception_handler
def validate_yaml(ctx, param, value):  # pylint: disable=unused-argument
    """
    Confirm that ``value`` names a readable file containing well-formed YAML.

    The signature follows a Click parameter callback; ``ctx`` and ``param`` are accepted but not used.

    Parameters
    ----------
    ctx: Click context (unused)
    param: Click parameter (unused)
    value: Path to the YAML file, or None

    Returns
    -------
    The unchanged ``value``; None passes straight through without touching the filesystem.

    Raises
    ------
    ValueError: Via ``yaml_exception_handler``, presumably when the file cannot be opened or parsed.
    """
    if value is not None:
        # Parse only to validate; the loaded document itself is discarded.
        with open(value, 'r', encoding="utf-8") as yaml_file:
            yaml.safe_load(yaml_file)
    return value
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def is_debugger_attached() -> bool:
    """
    Check if a debugger is attached to the current process.

    Two detection strategies are tried, in order:

    1. If ``debugpy`` has already been imported, its ``is_client_connected()`` answer is returned as-is.
       ``debugpy`` is only imported when it is already present in ``sys.modules``, so this check never loads
       the package on its own.
    2. Otherwise the current thread's trace function (``sys.gettrace()``) is inspected: a trace function whose
       defining module name contains ``"pydevd"`` is taken to mean a pydevd-based debugger is attached.

    Returns
    -------
    bool
        True if a debugger is attached, False otherwise
    """
    import sys

    if "debugpy" in sys.modules:
        import debugpy

        return debugpy.is_client_connected()

    trace_func = sys.gettrace()
    if trace_func is None:
        return False

    # The presence of a trace function coming from a pydevd module means a debugger is attached.
    trace_module = getattr(trace_func, "__module__", None)
    return trace_module is not None and "pydevd" in trace_module
|
|
File without changes
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import functools
import logging

import yaml
from pydantic import ValidationError
|
|
20
|
+
|
|
21
|
+
# Logger for this module; its name follows the module's import path.
logger = logging.getLogger(name=__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def schema_exception_handler(func, **kwargs):  # pylint: disable=unused-argument
    """
    A decorator that handles `ValidationError` exceptions for schema validation functions.

    This decorator wraps a function that performs schema validation using Pydantic.
    If a `ValidationError` is raised, it logs detailed error messages and raises a `ValueError` with the combined error
    messages. Each message is prefixed with the dotted location of the offending field (e.g. ``llms.nim.temperature``);
    errors that carry no location (such as those raised by model-level validators) are reported by message alone.

    Parameters
    ----------
    func : callable
        The function to be decorated. This function is expected to perform schema validation.

    kwargs : dict
        Accepted for backward compatibility and ignored; they are NOT forwarded to ``func``.

    Returns
    -------
    callable
        The wrapped function that executes `func` with exception handling. Its name and docstring are copied
        from ``func`` via ``functools.wraps``.

    Raises
    ------
    ValueError
        If a `ValidationError` is caught, this decorator logs the error details and raises a `ValueError` with the
        combined error messages. The original `ValidationError` is chained as ``__cause__``.

    Notes
    -----
    This decorator is particularly useful for functions that validate configurations or data models,
    ensuring that any validation errors are logged and communicated clearly.

    Examples
    --------
    >>> @schema_exception_handler
    ... def validate_config(config_data):
    ...     schema = MySchema(**config_data)
    ...     return schema
    ...
    >>> try:
    ...     validate_config(invalid_config)
    ... except ValueError as e:
    ...     print(f"Caught error: {e}")
    Caught error: Invalid configuration: field1: value is not a valid integer; field2: field required
    """

    @functools.wraps(func)
    def inner_function(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ValidationError as e:
            details = []
            for error in e.errors():
                # ``loc`` is a path tuple into the input. It can be empty (model-level validators), so never
                # index into it blindly; render the whole path so nested fields are identifiable.
                location = ".".join(str(part) for part in error["loc"])
                details.append(f"{location}: {error['msg']}" if location else error["msg"])
            error_messages = "; ".join(details)
            log_error_message = f"Invalid configuration: {error_messages}"
            logger.error(log_error_message)
            raise ValueError(log_error_message) from e

    return inner_function
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def yaml_exception_handler(func):
    """
    A decorator that handles YAML parsing exceptions.

    This decorator wraps a function that performs YAML file operations.
    If a YAML-related error occurs, it logs the error and raises a ValueError
    with a clear error message. Any other exception raised by the wrapped function
    is likewise logged and converted into a ValueError.

    Parameters
    ----------
    func : callable
        The function to be decorated.

    Returns
    -------
    callable
        The wrapped function that executes `func` with YAML exception handling. Its name and docstring are
        copied from ``func`` via ``functools.wraps``.

    Raises
    ------
    ValueError
        If a YAML error is caught, with details about the parsing failure; or, for any other exception, with a
        generic "Error reading YAML file" message. The original exception is chained as ``__cause__``.
    """

    @functools.wraps(func)
    def inner_function(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except yaml.YAMLError as e:
            log_error_message = f"Invalid YAML configuration: {e}"
            logger.error(log_error_message)
            raise ValueError(log_error_message) from e

        except Exception as e:
            # Deliberately broad: I/O failures (missing file, permissions, decoding) surface as ValueError too.
            log_error_message = f"Error reading YAML file: {e}"
            logger.error(log_error_message)
            raise ValueError(log_error_message) from e

    return inner_function
|
aiq/utils/io/__init__.py
ADDED
|
File without changes
|