aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc2__py3-none-any.whl
This diff shows the published contents of the two package versions as they appear in their public registries and is provided for informational purposes only.
Potentially problematic release: this version of aiqtoolkit has been flagged as potentially problematic.
- aiq/agent/base.py +170 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +146 -112
- aiq/agent/react_agent/prompt.py +1 -6
- aiq/agent/react_agent/register.py +36 -35
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/agent/tool_calling_agent/register.py +1 -1
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +37 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +44 -12
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +421 -61
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +2 -1
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +124 -12
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/dataset_handler.py +2 -1
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/evaluate.py +23 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +38 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +1 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +87 -2
- aiq/eval/evaluate.py +208 -27
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +1 -0
- aiq/eval/intermediate_step_adapter.py +11 -5
- aiq/eval/rag_evaluator/evaluate.py +55 -15
- aiq/eval/rag_evaluator/register.py +6 -1
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/trajectory_evaluator/evaluate.py +22 -65
- aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
- aiq/eval/tunable_rag_evaluator/register.py +2 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/output_uploader.py +10 -1
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/job_store.py +47 -25
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +108 -89
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +3 -2
- aiq/llm/register.py +1 -0
- aiq/meta/pypi.md +12 -12
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +36 -39
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/inference_metrics_model.py +3 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
- aiq/profiler/inference_optimization/data_models.py +2 -2
- aiq/profiler/inference_optimization/llm_metrics.py +2 -2
- aiq/profiler/profile_runner.py +61 -21
- aiq/runtime/loader.py +9 -3
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +25 -7
- aiq/runtime/user_metadata.py +2 -3
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +41 -6
- aiq/tool/mcp/mcp_tool.py +3 -2
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +6 -3
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/METADATA +53 -21
- aiqtoolkit-1.2.0rc2.dist-info/RECORD +436 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/WHEEL +1 -1
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -148
- aiq/observability/async_otel_listener.py +0 -429
- aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/top_level.txt +0 -0
aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py (added)
@@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import re

from aiq.builder.builder import Builder
from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.cli.register_workflow import register_its_strategy
from aiq.data_models.its_strategy import ITSStrategyBaseConfig
from aiq.experimental.inference_time_scaling.models.editor_config import LLMAsAJudgeEditorConfig
from aiq.experimental.inference_time_scaling.models.its_item import ITSItem
from aiq.experimental.inference_time_scaling.models.stage_enums import PipelineTypeEnum
from aiq.experimental.inference_time_scaling.models.stage_enums import StageTypeEnum
from aiq.experimental.inference_time_scaling.models.strategy_base import StrategyBase
from aiq.utils.io.model_processing import remove_r1_think_tags

logger = logging.getLogger(__name__)


class LLMAsAJudgeEditor(StrategyBase):
    """
    Given a list of PlanningItems, uses a feedback LLM to generate feedback on each plan
    Then edits the plan based on feedback.
    """

    def __init__(self, config: ITSStrategyBaseConfig) -> None:
        super().__init__(config)
        self.feedback_llm = None
        self.editing_llm = None

    async def build_components(self, builder: Builder) -> None:
        """
        Build the components required for the editor.
        """
        # Get the feedback LLM
        self.feedback_llm = await builder.get_llm(self.config.feedback_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

        self.editing_llm = await builder.get_llm(self.config.editing_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> [PipelineTypeEnum]:
        return [PipelineTypeEnum.PLANNING]

    def stage_type(self) -> StageTypeEnum:
        return StageTypeEnum.EDITING

    async def generate_feedback(self, llm, template, context: str, prompt: str, item: ITSItem) -> ITSItem:
        """
        Helper function to generate feedback for a given planning item using the provided prompt.
        """

        prompt = await template.ainvoke(
            input={
                "context": context,
                "original_prompt": prompt,  # Original prompt used to generate the plans
                "plan": item.plan,
                "num_feedback": self.config.num_feedback
            })

        feedback_result = await llm.ainvoke(prompt.to_string())
        if not feedback_result:
            logger.warning(f"No feedback generated for plan: {item.plan}.")
            return item

        # Update the planning item with the generated feedback
        cleaned = remove_r1_think_tags(
            feedback_result.content if hasattr(feedback_result, 'content') else str(feedback_result))

        # Feedback is the string following 'FEEDBACK:'. Use Regex to extract
        cleaned = re.sub(r'(?i)^\s*FEEDBACK:\s*', '', cleaned).strip()
        if not cleaned:
            logger.warning(f"Feedback was empty for plan: {item.plan}.")
            return item

        item.feedback = cleaned  # Set the feedback in the ITSItem

        return item

    async def edit_plan(self, llm, template, context: str, prompt: str, item: ITSItem) -> ITSItem:
        """
        Helper function to edit a plan based on feedback using the provided prompt.
        """

        if not item.feedback:
            logger.warning(f"No feedback available for plan: {item.plan}. Cannot edit.")
            return item

        prompt = await template.ainvoke(
            input={
                "context": context,
                "original_prompt": prompt,  # Original prompt used to generate the plans
                "plan": item.plan,
                "feedback": item.feedback
            })

        editing_result = await llm.ainvoke(prompt.to_string())
        if not editing_result:
            logger.warning(f"No editing result generated for plan: {item.plan}.")
            return item

        # Update the planning item with the edited plan
        cleaned = remove_r1_think_tags(
            editing_result.content if hasattr(editing_result, 'content') else str(editing_result))

        # Plan is the string following 'EDITED PLAN:'. Use Regex to extract
        cleaned = re.sub(r'(?i)^\s*EDITED PLAN:\s*', '', cleaned).strip()
        if not cleaned:
            logger.warning(f"Edited plan was empty for plan: {item.plan}. Returning original.")
            return item

        # Update the plan in the PlanningItem
        item.plan = cleaned

        return item

    async def ainvoke(self,
                      items: list[ITSItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[ITSItem]:
        """
        Edit the provided planning items using a feedback LLM.
        """
        from langchain_core.language_models import BaseChatModel
        from langchain_core.prompts import PromptTemplate

        # assert self.config.feedback_llm is a BaseChatModel
        if not isinstance(self.feedback_llm, BaseChatModel):
            raise ValueError("The `feedback_llm` must be an instance of `BaseChatModel`.")

        # assert self.config.editing_llm is a BaseChatModel
        if not isinstance(self.editing_llm, BaseChatModel):
            raise ValueError("The `editing_llm` must be an instance of `BaseChatModel`.")

        feedback_model: BaseChatModel = self.feedback_llm
        editing_model: BaseChatModel = self.editing_llm

        feedback_template = PromptTemplate(template=self.config.feedback_template,
                                           input_variables=["context", "original_prompt", "plan", "num_feedback"],
                                           validate_template=True)

        editing_template = PromptTemplate(template=self.config.editor_template,
                                          input_variables=["context", "original_prompt", "plan", "feedback"],
                                          validate_template=True)

        # Generate feedback for each planning item concurrently
        feedback_tasks = [
            self.generate_feedback(
                llm=feedback_model,
                template=feedback_template,
                context=agent_context,
                prompt=original_prompt,  # Original prompt used to generate the plans
                item=item) for item in items
        ]
        # Run the feedback tasks concurrently and gather results
        planning_items_with_feedback = await asyncio.gather(*feedback_tasks)

        if not planning_items_with_feedback:
            raise ValueError("No feedback was generated for the planning items. Please check the LLM response.")

        logger.info("Generated feedback for %d plans.", len(planning_items_with_feedback))

        # Now edit each planning item based on the feedback concurrently
        editing_tasks = [
            self.edit_plan(
                llm=editing_model,
                template=editing_template,
                context=agent_context,
                prompt=original_prompt,  # Original prompt used to generate the plans
                item=item) for item in planning_items_with_feedback
        ]
        # Run the editing tasks concurrently and gather results
        edited_planning_items = await asyncio.gather(*editing_tasks)

        if not edited_planning_items:
            raise ValueError("No plans were edited. Please check the LLM response.")

        logger.info("Edited %d plans based on feedback.", len(edited_planning_items))
        return edited_planning_items


@register_its_strategy(config_type=LLMAsAJudgeEditorConfig)
async def register_llm_as_a_judge_editor(config: ITSStrategyBaseConfig, builder: Builder):
    """
    Register the LLMAsAJudgeEditor strategy with the provided configuration and builder.
    """

    editor = LLMAsAJudgeEditor(config)
    await editor.build_components(builder)

    yield editor
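The 'FEEDBACK:' / 'EDITED PLAN:' parsing in this editor is a plain case-insensitive regex strip of a leading label. A minimal standalone sketch of that step, with a made-up LLM reply (illustrative only, not part of the package):

import re

raw = "FEEDBACK: Step 2 duplicates step 1; merge them and add a verification step."
cleaned = re.sub(r'(?i)^\s*FEEDBACK:\s*', '', raw).strip()
print(cleaned)  # -> "Step 2 duplicates step 1; merge them and add a verification step."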
aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py (added)
@@ -0,0 +1,107 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from aiq.builder.builder import Builder
from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.cli.register_workflow import register_its_strategy
from aiq.experimental.inference_time_scaling.models.editor_config import MotivationAwareSummarizationConfig
from aiq.experimental.inference_time_scaling.models.its_item import ITSItem
from aiq.experimental.inference_time_scaling.models.stage_enums import PipelineTypeEnum
from aiq.experimental.inference_time_scaling.models.stage_enums import StageTypeEnum
from aiq.experimental.inference_time_scaling.models.strategy_base import StrategyBase
from aiq.utils.io.model_processing import remove_r1_think_tags

logger = logging.getLogger(__name__)


class MotivationAwareSummarization(StrategyBase):
    """
    A strategy that, for each incoming ITSItem, summarizes the output based on input
    and motivation.
    """

    def __init__(self, config: MotivationAwareSummarizationConfig) -> None:
        super().__init__(config)
        self.config = config
        self.llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Binds each LLMRef in self.config.llms to an actual LLM client.
        """
        bound_llm = await builder.get_llm(self.config.editor_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
        self.llm_bound = bound_llm

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        return [PipelineTypeEnum.TOOL_USE]

    def stage_type(self) -> StageTypeEnum:
        return StageTypeEnum.EDITING

    async def ainvoke(self,
                      items: list[ITSItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[ITSItem]:
        """
        For each ITSItem, rewrite the 'input' using each LLM to create a new perspective.
        The new ITSItems' 'output' field will store the newly generated query.
        """
        try:
            from langchain_core.prompts import PromptTemplate
        except ImportError:
            raise ImportError("langchain-core is required for MultiQueryRetrievalSearch. "
                              "Install aiqtoolkit-langchain or similar.")

        new_its_items: list[ITSItem] = []

        # Create a single PromptTemplate object for rewriting the query
        template_vars = ["task", "motivation", "output"]
        query_template = PromptTemplate(template=self.config.editor_template,
                                        input_variables=template_vars,
                                        validate_template=True)

        for item in items:
            original_task = str(item.input) or ""
            motivation = str(item.metadata) if item.metadata else ""
            output = str(item.output) if item.output else ""

            prompt = await (query_template.ainvoke(input={
                "task": original_task, "motivation": motivation, "output": output
            }))

            llm_response = await self.llm_bound.ainvoke(prompt.to_string())
            llm_response = remove_r1_think_tags(llm_response.content)

            logger.info("LLM response from summarization: %s", llm_response)

            new_its_items.append(
                ITSItem(
                    input=item.input,
                    output=remove_r1_think_tags(llm_response),
                    metadata=item.metadata,
                    name=item.name,  # keep the original tool name
                ))

        return new_its_items


@register_its_strategy(config_type=MotivationAwareSummarizationConfig)
async def register_multi_query_retrieval_search(config: MotivationAwareSummarizationConfig, builder: Builder):
    strategy = MotivationAwareSummarization(config)
    await strategy.build_components(builder)
    yield strategy
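The summarizer above renders its editor template with langchain-core's PromptTemplate before calling the bound LLM. A minimal sketch of that render step, assuming langchain-core is installed and using a made-up template string (illustrative only, not part of the package):

import asyncio

from langchain_core.prompts import PromptTemplate

# Hypothetical template with the same variables the strategy above uses
template = PromptTemplate(template="Task: {task}\nMotivation: {motivation}\nRaw output: {output}\nSummarize the output for the task.",
                          input_variables=["task", "motivation", "output"],
                          validate_template=True)


async def main() -> None:
    prompt = await template.ainvoke(input={"task": "look up the weather", "motivation": "trip planning", "output": "72F and sunny"})
    print(prompt.to_string())  # rendered prompt text that would be sent to the LLM


asyncio.run(main())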
aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py (added)
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from pydantic import Field

from aiq.builder.builder import Builder
from aiq.builder.function import Function
from aiq.builder.function_info import FunctionInfo
from aiq.cli.register_workflow import register_function
from aiq.data_models.component_ref import FunctionRef
from aiq.data_models.component_ref import ITSStrategyRef
from aiq.data_models.function import FunctionBaseConfig
from aiq.experimental.inference_time_scaling.models.its_item import ITSItem
from aiq.experimental.inference_time_scaling.models.stage_enums import PipelineTypeEnum
from aiq.experimental.inference_time_scaling.models.stage_enums import StageTypeEnum

logger = logging.getLogger(__name__)


class ExecuteScoreSelectFunctionConfig(FunctionBaseConfig, name="execute_score_select_function"):
    scorer: ITSStrategyRef | None = Field(description="Strategy to score the output of the function", default=None)
    selector: ITSStrategyRef = Field(description="Strategy to select the best output of the function")
    augmented_fn: FunctionRef = Field(description="Function that will be executed")

    num_executions: int = Field(3, description="Number of times to execute the function")


@register_function(config_type=ExecuteScoreSelectFunctionConfig)
async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig, builder: Builder):
    import asyncio
    import warnings

    from pydantic import BaseModel

    executable_fn: Function = builder.get_function(name=config.augmented_fn)

    if config.scorer:
        scorer = await builder.get_its_strategy(strategy_name=config.scorer,
                                                pipeline_type=PipelineTypeEnum.AGENT_EXECUTION,
                                                stage_type=StageTypeEnum.SCORING)
    else:
        scorer = None

    selector = await builder.get_its_strategy(strategy_name=config.selector,
                                              pipeline_type=PipelineTypeEnum.AGENT_EXECUTION,
                                              stage_type=StageTypeEnum.SELECTION)

    if executable_fn.has_streaming_output:
        warnings.warn("Streaming output is not supported for this function. "
                      "The function will be executed in non-streaming mode.")

    def convert_to_str(arg):
        if isinstance(arg, BaseModel):
            return str(arg.model_dump())
        return str(arg)

    async def execute_fn(input_msg: executable_fn.input_type) -> executable_fn.single_output_type:

        logger.info("Executing function %d times", config.num_executions)
        tasks = [executable_fn.ainvoke(input_msg) for _ in range(config.num_executions)]
        results = await asyncio.gather(*tasks)

        input_str = convert_to_str(input_msg)
        function_outputs = [convert_to_str(out) for out in results]
        its_items = [ITSItem(
            input=input_str,
            output=out,
        ) for out in function_outputs]

        if scorer:
            logger.info("Beginning scoring")
            its_items = await scorer.ainvoke(items=its_items)

        logger.info("Beginning selection")
        selected_item = (await selector.ainvoke(items=its_items, original_prompt=its_items[0].input))[0]

        # Find the index of selected item in its_items by matching the output
        selected_output = selected_item.output
        selected_index = -1
        for i, item in enumerate(its_items):
            if item.output == selected_output:
                selected_index = i
                break

        return results[selected_index] if selected_index != -1 else selected_output

    yield FunctionInfo.from_fn(
        fn=execute_fn,
        description=("This function executes a given function multiple times, scores the outputs, "
                     "and selects the best output based on the specified scoring and selection strategies."),
    )
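The execute/score/select function above stringifies every candidate before scoring, then maps the selector's (stringified) pick back to the original typed result by matching outputs. A minimal sketch of that index-matching step with plain Python values standing in for the real results (illustrative only, not part of the package):

# Raw outputs from repeated executions, and their stringified candidates
results = [{"answer": 4}, {"answer": 5}, {"answer": 4}]
candidates = [str(r) for r in results]

# Stand-in for the output string a selection strategy might return
selected_output = candidates[2]

# Recover the original (typed) result by matching the stringified output
selected_index = next((i for i, c in enumerate(candidates) if c == selected_output), -1)
print(results[selected_index] if selected_index != -1 else selected_output)  # {'answer': 4}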
aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py (added)
@@ -0,0 +1,205 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging

from pydantic import Field

from aiq.builder.builder import Builder
from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.builder.function_info import FunctionInfo
from aiq.cli.register_workflow import register_function
from aiq.data_models.component_ref import FunctionRef
from aiq.data_models.component_ref import ITSStrategyRef
from aiq.data_models.function import FunctionBaseConfig
from aiq.experimental.inference_time_scaling.models.its_item import ITSItem
from aiq.experimental.inference_time_scaling.models.stage_enums import PipelineTypeEnum
from aiq.experimental.inference_time_scaling.models.stage_enums import StageTypeEnum
from aiq.experimental.inference_time_scaling.models.tool_use_config import ToolUseInputSchema
from aiq.experimental.inference_time_scaling.models.tool_use_config import ToolUselist

logger = logging.getLogger(__name__)


class ITSToolOrchestrationFunctionConfig(FunctionBaseConfig, name="its_tool_orchestration"):
    """
    Configuration for the ITSToolOrchestrationFunction, which is used to orchestrate multiple functions.
    """

    augmented_fns: list[FunctionRef] = Field(
        description="list of FunctionRefs for the functions to be orchestrated. Must be wrapped in `its_tool_wrapper`.")

    search_strategy: ITSStrategyRef | None = Field(
        description="The ITS search strategy to use for orchestrating invocation of the functions."
        " If None, no search will be performed.",
        default=None,
    )

    editing_strategy: ITSStrategyRef | None = Field(
        default=None,
        description="The ITS editing strategy to use for orchestrating invocation of the functions. "
        "If None, no editing will be performed.",
    )

    scoring_strategy: ITSStrategyRef | None = Field(
        default=None,
        description="The ITS scoring strategy to use for orchestrating invocation of the functions. "
        "If None, no scoring will be performed.",
    )

    selection_strategy: ITSStrategyRef = Field(
        description="The ITS selection strategy to use for orchestrating invocation of the functions.")


@register_function(config_type=ITSToolOrchestrationFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def register_its_tool_orchestration_function(
        config: ITSToolOrchestrationFunctionConfig,
        builder: Builder,
):
    """
    Registers an ITS-based orchestration function that:
    1. Instantiates all relevant strategies (search, editing, scoring, selection).
    2. Accepts a ToolUselist, converts each item to an ITSItem, optionally runs search/editing.
    3. Calls the correct augmented_fn per item using name=tool name.
    4. If configured, runs scoring and selection on the result.
    5. Returns a new ToolUselist with each output set.
    """

    # 1) Gather references to all augmented (wrapped) functions
    function_map = {}
    for fn_ref in config.augmented_fns:
        # Retrieve the actual function from the builder
        fn_obj = builder.get_function(fn_ref)
        function_map[fn_ref] = fn_obj

    # 2) Instantiate search, editing, scoring, selection strategies (if any)
    search = None
    if config.search_strategy is not None:
        search = await builder.get_its_strategy(
            strategy_name=config.search_strategy,
            pipeline_type=PipelineTypeEnum.TOOL_USE,
            stage_type=StageTypeEnum.SEARCH,
        )

    editing = None
    if config.editing_strategy is not None:
        editing = await builder.get_its_strategy(
            strategy_name=config.editing_strategy,
            pipeline_type=PipelineTypeEnum.TOOL_USE,
            stage_type=StageTypeEnum.EDITING,
        )

    scoring = None
    if config.scoring_strategy is not None:
        scoring = await builder.get_its_strategy(
            strategy_name=config.scoring_strategy,
            pipeline_type=PipelineTypeEnum.TOOL_USE,
            stage_type=StageTypeEnum.SCORING,
        )

    selection = await builder.get_its_strategy(
        strategy_name=config.selection_strategy,
        pipeline_type=PipelineTypeEnum.TOOL_USE,
        stage_type=StageTypeEnum.SELECTION,
    )

    fn_description = ("\n".join(f"- **{fn_ref}**: {function_map[fn_ref].description or 'No description provided.'}"
                                for fn_ref in config.augmented_fns))

    # 3) Create the inner function to handle single (non-streaming) calls.
    async def single_inner(tool_list: ToolUselist) -> ToolUselist:
        """
        Orchestrates multiple tool usages, optionally using search/editing/scoring/selection steps.
        """
        # Convert each ToolUseInputSchema to ITSItem
        its_items = []
        for t in tool_list.tools:
            item = ITSItem(
                input=t.task_description,  # The user "task"
                output=None,
                name=t.tool_name,  # The "tool name"
                metadata=t.motivation,  # The "justification"
            )
            its_items.append(item)

        # Run search strategy if present
        if search is not None:
            its_items = await search.ainvoke(its_items)

        logger.info("ITS orchestration function: %d items after search", len(its_items))

        # Invoke the correct augmented function for each item concurrently
        # Helper coroutine to invoke a tool function and capture result or error
        async def _invoke_tool(item: ITSItem, fn):
            try:
                result = await fn.acall_invoke(item.output)
                return item, result, None
            except Exception as e:
                logger.error(f"Error invoking function '{item.name}': {e}")
                return item, None, str(e)

        tasks = []
        for item in its_items:
            if item.name not in function_map:
                logger.error(f"Function '{item.name}' not found in function map.")
                item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
            else:
                fn = function_map[item.name]
                tasks.append(_invoke_tool(item, fn))

        # Await all tasks and assign outputs
        if tasks:
            results = await asyncio.gather(*tasks)
            for item, result, error in results:
                if error:
                    item.output = f"Error invoking function '{item.name}': {error}"
                else:
                    item.output = result

        if editing:
            its_items = await editing.ainvoke(its_items)

        # Run scoring strategy if present
        if scoring is not None:
            its_items = await scoring.ainvoke(its_items)

        # Run selection strategy
        if selection is not None:
            its_items = await selection.ainvoke(its_items)

        logger.info("ITS orchestration function: %d items after selection", len(its_items))

        # Convert final results from ITSItems back to a ToolUselist
        final_list = ToolUselist(tools=[])
        for item in its_items:
            # Compose a new ToolUseInputSchema with final output
            new_tool = ToolUseInputSchema(
                tool_name=item.name,
                task_description=str(item.input),
                motivation=item.metadata if item.metadata else None,
                output=str(item.output) if item.output is not None else None,
            )
            final_list.tools.append(new_tool)

        return final_list

    # 4) Return the function info (only a single_fn is needed; no streaming)
    yield FunctionInfo.create(
        single_fn=single_inner,
        stream_fn=None,  # No streaming required
        input_schema=ToolUselist,
        single_output_schema=ToolUselist,
        description=fn_description)
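The orchestration function above round-trips each tool request through an ITSItem (task_description -> input, tool_name -> name, motivation -> metadata) and writes the invocation result back into output. A minimal sketch of that field mapping, with plain dicts standing in for ToolUseInputSchema and ITSItem and a made-up tool request (illustrative only, not part of the package):

tool_requests = [{"tool_name": "web_search",
                  "task_description": "Find the 2024 population of Lisbon",
                  "motivation": "needed to size the example"}]

# ToolUseInputSchema -> ITSItem
its_items = [{"name": t["tool_name"], "input": t["task_description"],
              "metadata": t["motivation"], "output": None} for t in tool_requests]

# Each item's "output" would be filled in by invoking the matching wrapped function
for item in its_items:
    item["output"] = f"result for: {item['input']}"

# ITSItem -> ToolUseInputSchema
final_tools = [{"tool_name": i["name"], "task_description": i["input"],
                "motivation": i["metadata"], "output": i["output"]} for i in its_items]
print(final_tools)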