aiqtoolkit 1.2.0rc2__py3-none-any.whl → 1.2.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- aiq/agent/base.py +8 -7
- aiq/agent/react_agent/agent.py +2 -3
- aiq/agent/react_agent/register.py +1 -1
- aiq/agent/reasoning_agent/reasoning_agent.py +2 -1
- aiq/agent/tool_calling_agent/register.py +2 -1
- aiq/authentication/api_key/api_key_auth_provider.py +6 -2
- aiq/builder/function.py +21 -6
- aiq/builder/function_base.py +6 -2
- aiq/cli/commands/sizing/calc.py +6 -3
- aiq/cli/commands/start.py +0 -5
- aiq/cli/commands/uninstall.py +2 -4
- aiq/data_models/api_server.py +6 -12
- aiq/data_models/component_ref.py +1 -1
- aiq/data_models/discovery_metadata.py +62 -13
- aiq/front_ends/console/console_front_end_plugin.py +2 -22
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +4 -2
- aiq/object_store/in_memory_object_store.py +18 -16
- aiq/observability/exporter/processing_exporter.py +99 -46
- aiq/observability/exporter/span_exporter.py +1 -0
- aiq/observability/processor/batching_processor.py +52 -59
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/processor.py +4 -1
- aiq/profiler/calc/calc_runner.py +5 -1
- aiq/profiler/calc/data_models.py +18 -6
- aiq/registry_handlers/package_utils.py +397 -28
- aiq/runtime/loader.py +23 -2
- aiq/tool/code_execution/README.md +0 -1
- aiq/tool/server_tools.py +1 -1
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/type_converter.py +52 -10
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/METADATA +1 -1
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/RECORD +37 -35
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/WHEEL +0 -0
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/entry_points.txt +0 -0
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/top_level.txt +0 -0
aiq/agent/base.py
CHANGED
|
@@ -148,7 +148,7 @@ class BaseAgent(ABC):
|
|
|
148
148
|
"""
|
|
149
149
|
last_exception = None
|
|
150
150
|
|
|
151
|
-
for attempt in range(max_retries + 1):
|
|
151
|
+
for attempt in range(1, max_retries + 1):
|
|
152
152
|
try:
|
|
153
153
|
response = await tool.ainvoke(tool_input, config=config)
|
|
154
154
|
|
|
@@ -162,17 +162,18 @@ class BaseAgent(ABC):
|
|
|
162
162
|
|
|
163
163
|
except Exception as e:
|
|
164
164
|
last_exception = e
|
|
165
|
-
logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
|
|
166
|
-
AGENT_LOG_PREFIX,
|
|
167
|
-
attempt + 1,
|
|
168
|
-
max_retries + 1,
|
|
169
|
-
tool.name,
|
|
170
|
-
str(e))
|
|
171
165
|
|
|
172
166
|
# If this was the last attempt, don't sleep
|
|
173
167
|
if attempt == max_retries:
|
|
174
168
|
break
|
|
175
169
|
|
|
170
|
+
logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
|
|
171
|
+
AGENT_LOG_PREFIX,
|
|
172
|
+
attempt,
|
|
173
|
+
max_retries,
|
|
174
|
+
tool.name,
|
|
175
|
+
str(e))
|
|
176
|
+
|
|
176
177
|
# Exponential backoff: 2^attempt seconds
|
|
177
178
|
sleep_time = 2**attempt
|
|
178
179
|
logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
|
aiq/agent/react_agent/agent.py
CHANGED
|
@@ -193,12 +193,11 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
193
193
|
ex.observation,
|
|
194
194
|
output_message.content)
|
|
195
195
|
if attempt == self.parse_agent_response_max_retries:
|
|
196
|
-
logger.
|
|
196
|
+
logger.warning(
|
|
197
197
|
"%s Failed to parse agent output after %d attempts, consider enabling or "
|
|
198
198
|
"increasing parse_agent_response_max_retries",
|
|
199
199
|
AGENT_LOG_PREFIX,
|
|
200
|
-
attempt
|
|
201
|
-
exc_info=True)
|
|
200
|
+
attempt)
|
|
202
201
|
# the final answer goes in the "messages" state channel
|
|
203
202
|
combined_content = str(ex.observation) + '\n' + str(output_message.content)
|
|
204
203
|
output_message.content = combined_content
|
|
@@ -18,7 +18,6 @@ import logging
|
|
|
18
18
|
from pydantic import AliasChoices
|
|
19
19
|
from pydantic import Field
|
|
20
20
|
|
|
21
|
-
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
22
21
|
from aiq.builder.builder import Builder
|
|
23
22
|
from aiq.builder.framework_enum import LLMFrameworkEnum
|
|
24
23
|
from aiq.builder.function_info import FunctionInfo
|
|
@@ -79,6 +78,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
79
78
|
from langchain_core.messages import trim_messages
|
|
80
79
|
from langgraph.graph.graph import CompiledGraph
|
|
81
80
|
|
|
81
|
+
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
82
82
|
from aiq.agent.react_agent.agent import ReActAgentGraph
|
|
83
83
|
from aiq.agent.react_agent.agent import ReActGraphState
|
|
84
84
|
from aiq.agent.react_agent.agent import create_react_agent_prompt
|
|
@@ -19,7 +19,6 @@ from collections.abc import AsyncGenerator
|
|
|
19
19
|
|
|
20
20
|
from pydantic import Field
|
|
21
21
|
|
|
22
|
-
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
23
22
|
from aiq.builder.builder import Builder
|
|
24
23
|
from aiq.builder.framework_enum import LLMFrameworkEnum
|
|
25
24
|
from aiq.builder.function_info import FunctionInfo
|
|
@@ -86,6 +85,8 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
86
85
|
from langchain_core.language_models import BaseChatModel
|
|
87
86
|
from langchain_core.prompts import PromptTemplate
|
|
88
87
|
|
|
88
|
+
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
89
|
+
|
|
89
90
|
def remove_r1_think_tags(text: str):
|
|
90
91
|
pattern = r'(<think>)?.*?</think>\s*(.*)'
|
|
91
92
|
|
|
@@ -17,7 +17,6 @@ import logging
|
|
|
17
17
|
|
|
18
18
|
from pydantic import Field
|
|
19
19
|
|
|
20
|
-
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
21
20
|
from aiq.builder.builder import Builder
|
|
22
21
|
from aiq.builder.framework_enum import LLMFrameworkEnum
|
|
23
22
|
from aiq.builder.function_info import FunctionInfo
|
|
@@ -49,6 +48,8 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
49
48
|
from langchain_core.messages.human import HumanMessage
|
|
50
49
|
from langgraph.graph.graph import CompiledGraph
|
|
51
50
|
|
|
51
|
+
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
52
|
+
|
|
52
53
|
from .agent import ToolCallAgentGraph
|
|
53
54
|
from .agent import ToolCallAgentGraphState
|
|
54
55
|
|
|
@@ -28,9 +28,13 @@ logger = logging.getLogger(__name__)
|
|
|
28
28
|
|
|
29
29
|
class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
|
|
30
30
|
|
|
31
|
-
|
|
32
|
-
|
|
31
|
+
# fmt: off
|
|
32
|
+
def __init__(self,
|
|
33
|
+
config: APIKeyAuthProviderConfig,
|
|
34
|
+
config_name: str | None = None) -> None: # pylint: disable=unused-argument
|
|
35
|
+
assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyAuthProviderConfig")
|
|
33
36
|
super().__init__(config)
|
|
37
|
+
# fmt: on
|
|
34
38
|
|
|
35
39
|
async def _construct_authentication_header(self) -> BearerTokenCred:
|
|
36
40
|
"""
|
aiq/builder/function.py
CHANGED
|
@@ -76,11 +76,16 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
76
76
|
-------
|
|
77
77
|
_T
|
|
78
78
|
The converted value.
|
|
79
|
+
|
|
80
|
+
Raises
|
|
81
|
+
------
|
|
82
|
+
ValueError
|
|
83
|
+
If the value cannot be converted to the specified type (when `to_type` is specified).
|
|
79
84
|
"""
|
|
80
85
|
|
|
81
86
|
return self._converter.convert(value, to_type=to_type)
|
|
82
87
|
|
|
83
|
-
def try_convert(self, value: typing.Any, to_type: type[_T]) -> _T:
|
|
88
|
+
def try_convert(self, value: typing.Any, to_type: type[_T]) -> _T | typing.Any:
|
|
84
89
|
"""
|
|
85
90
|
Converts the given value to the specified type using graceful error handling.
|
|
86
91
|
If conversion fails, returns the original value and continues processing.
|
|
@@ -94,7 +99,7 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
94
99
|
|
|
95
100
|
Returns
|
|
96
101
|
-------
|
|
97
|
-
_T
|
|
102
|
+
_T | typing.Any
|
|
98
103
|
The converted value, or original value if conversion fails.
|
|
99
104
|
"""
|
|
100
105
|
return self._converter.try_convert(value, to_type=to_type)
|
|
@@ -129,17 +134,22 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
129
134
|
-------
|
|
130
135
|
typing.Any
|
|
131
136
|
The output of the function optionally converted to the specified type.
|
|
137
|
+
|
|
138
|
+
Raises
|
|
139
|
+
------
|
|
140
|
+
ValueError
|
|
141
|
+
If the output of the function cannot be converted to the specified type.
|
|
132
142
|
"""
|
|
133
143
|
|
|
134
144
|
with self._context.push_active_function(self.instance_name,
|
|
135
145
|
input_data=value) as manager: # Set the current invocation context
|
|
136
146
|
try:
|
|
137
|
-
converted_input: InputT = self._convert_input(value)
|
|
147
|
+
converted_input: InputT = self._convert_input(value)
|
|
138
148
|
|
|
139
149
|
result = await self._ainvoke(converted_input)
|
|
140
150
|
|
|
141
151
|
if to_type is not None and not isinstance(result, to_type):
|
|
142
|
-
result = self.
|
|
152
|
+
result = self.convert(result, to_type)
|
|
143
153
|
|
|
144
154
|
manager.set_output(result)
|
|
145
155
|
|
|
@@ -215,18 +225,23 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
215
225
|
------
|
|
216
226
|
typing.Any
|
|
217
227
|
The output of the function optionally converted to the specified type.
|
|
228
|
+
|
|
229
|
+
Raises
|
|
230
|
+
------
|
|
231
|
+
ValueError
|
|
232
|
+
If the output of the function cannot be converted to the specified type (when `to_type` is specified).
|
|
218
233
|
"""
|
|
219
234
|
|
|
220
235
|
with self._context.push_active_function(self.instance_name, input_data=value) as manager:
|
|
221
236
|
try:
|
|
222
|
-
converted_input: InputT = self._convert_input(value)
|
|
237
|
+
converted_input: InputT = self._convert_input(value)
|
|
223
238
|
|
|
224
239
|
# Collect streaming outputs to capture the final result
|
|
225
240
|
final_output: list[typing.Any] = []
|
|
226
241
|
|
|
227
242
|
async for data in self._astream(converted_input):
|
|
228
243
|
if to_type is not None and not isinstance(data, to_type):
|
|
229
|
-
converted_data = self.
|
|
244
|
+
converted_data = self.convert(data, to_type=to_type)
|
|
230
245
|
final_output.append(converted_data)
|
|
231
246
|
yield converted_data
|
|
232
247
|
else:
|
aiq/builder/function_base.py
CHANGED
|
@@ -350,7 +350,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
350
350
|
# output because the ABC has it.
|
|
351
351
|
return True
|
|
352
352
|
|
|
353
|
-
def _convert_input(self, value: typing.Any):
|
|
353
|
+
def _convert_input(self, value: typing.Any) -> InputT:
|
|
354
354
|
if (isinstance(value, self.input_class)):
|
|
355
355
|
return value
|
|
356
356
|
|
|
@@ -373,4 +373,8 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
373
373
|
return value
|
|
374
374
|
|
|
375
375
|
# Fallback to the converter
|
|
376
|
-
|
|
376
|
+
try:
|
|
377
|
+
return self._converter.convert(value, to_type=self.input_class)
|
|
378
|
+
except ValueError as e:
|
|
379
|
+
# Input parsing should yield a TypeError instead of a ValueError
|
|
380
|
+
raise TypeError from e
|
aiq/cli/commands/sizing/calc.py
CHANGED
|
@@ -274,9 +274,12 @@ def calc_command(ctx,
|
|
|
274
274
|
|
|
275
275
|
click.echo(tabulate(table, headers=headers, tablefmt="github"))
|
|
276
276
|
|
|
277
|
-
# Display slope-based GPU estimates
|
|
278
|
-
|
|
279
|
-
|
|
277
|
+
# Display slope-based GPU estimates if they are available
|
|
278
|
+
if results.gpu_estimates.gpu_estimate_by_llm_latency is not None or \
|
|
279
|
+
results.gpu_estimates.gpu_estimate_by_wf_runtime is not None:
|
|
280
|
+
click.echo("")
|
|
281
|
+
click.echo(click.style("=== GPU ESTIMATES ===", fg="bright_blue", bold=True))
|
|
282
|
+
|
|
280
283
|
if results.gpu_estimates.gpu_estimate_by_wf_runtime is not None:
|
|
281
284
|
click.echo(
|
|
282
285
|
click.style(
|
aiq/cli/commands/start.py
CHANGED
|
@@ -190,11 +190,6 @@ class StartCommandGroup(click.Group):
|
|
|
190
190
|
# Override default front end config with values from the config file for serverless execution modes.
|
|
191
191
|
# Check that we have the right kind of front end
|
|
192
192
|
if (not isinstance(config.general.front_end, front_end.config_type)):
|
|
193
|
-
logger.warning(
|
|
194
|
-
"The front end type in the config file (%s) does not match the command name (%s). "
|
|
195
|
-
"Overwriting the config file front end.",
|
|
196
|
-
config.general.front_end.type,
|
|
197
|
-
cmd_name)
|
|
198
193
|
|
|
199
194
|
# Set the front end config
|
|
200
195
|
config.general.front_end = front_end.config_type()
|
aiq/cli/commands/uninstall.py
CHANGED
|
@@ -53,13 +53,11 @@ async def uninstall_packages(packages: list[dict[str, str]]) -> None:
|
|
|
53
53
|
await stack.enter_async_context(registry_handler.remove(packages=package_name_list))
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
@click.group(name=__name__,
|
|
57
|
-
invoke_without_command=True,
|
|
58
|
-
help=("Uninstall an AIQ Toolkit plugin packages from the local environment."))
|
|
56
|
+
@click.group(name=__name__, invoke_without_command=True, help=("Uninstall plugin packages from the local environment."))
|
|
59
57
|
@click.argument("packages", type=str)
|
|
60
58
|
def uninstall_command(packages: str) -> None:
|
|
61
59
|
"""
|
|
62
|
-
Uninstall
|
|
60
|
+
Uninstall plugin packages from the local environment.
|
|
63
61
|
"""
|
|
64
62
|
|
|
65
63
|
packages = packages.split()
|
aiq/data_models/api_server.py
CHANGED
|
@@ -121,28 +121,22 @@ class AIQChatRequest(BaseModel):
|
|
|
121
121
|
# Optional fields (OpenAI Chat Completions API compatible)
|
|
122
122
|
model: str | None = Field(default=None, description="name of the model to use")
|
|
123
123
|
frequency_penalty: float | None = Field(default=0.0,
|
|
124
|
-
ge=-2.0,
|
|
125
|
-
le=2.0,
|
|
126
124
|
description="Penalty for new tokens based on frequency in text")
|
|
127
125
|
logit_bias: dict[str, float] | None = Field(default=None,
|
|
128
126
|
description="Modify likelihood of specified tokens appearing")
|
|
129
127
|
logprobs: bool | None = Field(default=None, description="Whether to return log probabilities")
|
|
130
|
-
top_logprobs: int | None = Field(default=None,
|
|
131
|
-
max_tokens: int | None = Field(default=None,
|
|
132
|
-
n: int | None = Field(default=1,
|
|
133
|
-
presence_penalty: float | None = Field(default=0.0,
|
|
134
|
-
ge=-2.0,
|
|
135
|
-
le=2.0,
|
|
136
|
-
description="Penalty for new tokens based on presence in text")
|
|
128
|
+
top_logprobs: int | None = Field(default=None, description="Number of most likely tokens to return")
|
|
129
|
+
max_tokens: int | None = Field(default=None, description="Maximum number of tokens to generate")
|
|
130
|
+
n: int | None = Field(default=1, description="Number of chat completion choices to generate")
|
|
131
|
+
presence_penalty: float | None = Field(default=0.0, description="Penalty for new tokens based on presence in text")
|
|
137
132
|
response_format: dict[str, typing.Any] | None = Field(default=None, description="Response format specification")
|
|
138
133
|
seed: int | None = Field(default=None, description="Random seed for deterministic sampling")
|
|
139
134
|
service_tier: typing.Literal["auto", "default"] | None = Field(default=None,
|
|
140
135
|
description="Service tier for the request")
|
|
141
|
-
stop: str | list[str] | None = Field(default=None, description="Up to 4 sequences where API will stop generating")
|
|
142
136
|
stream: bool | None = Field(default=False, description="Whether to stream partial message deltas")
|
|
143
137
|
stream_options: dict[str, typing.Any] | None = Field(default=None, description="Options for streaming")
|
|
144
|
-
temperature: float | None = Field(default=1.0,
|
|
145
|
-
top_p: float | None = Field(default=None,
|
|
138
|
+
temperature: float | None = Field(default=1.0, description="Sampling temperature between 0 and 2")
|
|
139
|
+
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
|
|
146
140
|
tools: list[dict[str, typing.Any]] | None = Field(default=None, description="List of tools the model may call")
|
|
147
141
|
tool_choice: str | dict[str, typing.Any] | None = Field(default=None, description="Controls which tool is called")
|
|
148
142
|
parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
|
aiq/data_models/component_ref.py
CHANGED
|
@@ -21,6 +21,7 @@ import typing
|
|
|
21
21
|
from enum import Enum
|
|
22
22
|
from functools import lru_cache
|
|
23
23
|
from pathlib import Path
|
|
24
|
+
from types import ModuleType
|
|
24
25
|
from typing import TYPE_CHECKING
|
|
25
26
|
|
|
26
27
|
from pydantic import BaseModel
|
|
@@ -115,6 +116,55 @@ class DiscoveryMetadata(BaseModel):
|
|
|
115
116
|
return data.get(root_package, None)
|
|
116
117
|
return None
|
|
117
118
|
|
|
119
|
+
@staticmethod
|
|
120
|
+
@lru_cache
|
|
121
|
+
def get_distribution_name_from_module(module: ModuleType | None) -> str:
|
|
122
|
+
"""Get the distribution name from the config type using the mapping of module names to distro names.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
module (ModuleType): A registered component's module.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
str: The distribution name of the AIQ Toolkit component.
|
|
129
|
+
"""
|
|
130
|
+
from aiq.runtime.loader import get_all_aiq_entrypoints_distro_mapping
|
|
131
|
+
|
|
132
|
+
if module is None:
|
|
133
|
+
return "aiqtoolkit"
|
|
134
|
+
|
|
135
|
+
# Get the mapping of module names to distro names
|
|
136
|
+
mapping = get_all_aiq_entrypoints_distro_mapping()
|
|
137
|
+
module_package = module.__package__
|
|
138
|
+
|
|
139
|
+
if module_package is None:
|
|
140
|
+
return "aiqtoolkit"
|
|
141
|
+
|
|
142
|
+
# Traverse the module package parts in reverse order to find the distro name
|
|
143
|
+
# This is because the module package is the root package for the AIQ Toolkit component
|
|
144
|
+
# and the distro name is the name of the package that contains the component
|
|
145
|
+
module_package_parts = module_package.split(".")
|
|
146
|
+
for part_idx in range(len(module_package_parts), 0, -1):
|
|
147
|
+
candidate_module_name = ".".join(module_package_parts[0:part_idx])
|
|
148
|
+
candidate_distro_name = mapping.get(candidate_module_name, None)
|
|
149
|
+
if candidate_distro_name is not None:
|
|
150
|
+
return candidate_distro_name
|
|
151
|
+
|
|
152
|
+
return "aiqtoolkit"
|
|
153
|
+
|
|
154
|
+
@staticmethod
|
|
155
|
+
@lru_cache
|
|
156
|
+
def get_distribution_name_from_config_type(config_type: type["TypedBaseModelT"]) -> str:
|
|
157
|
+
"""Get the distribution name from the config type using the mapping of module names to distro names.
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
config_type (type[TypedBaseModelT]): A registered component's configuration object.
|
|
161
|
+
|
|
162
|
+
Returns:
|
|
163
|
+
str: The distribution name of the AIQ Toolkit component.
|
|
164
|
+
"""
|
|
165
|
+
module = inspect.getmodule(config_type)
|
|
166
|
+
return DiscoveryMetadata.get_distribution_name_from_module(module)
|
|
167
|
+
|
|
118
168
|
@staticmethod
|
|
119
169
|
@lru_cache
|
|
120
170
|
def get_distribution_name(root_package: str) -> str:
|
|
@@ -123,6 +173,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
123
173
|
root package name 'aiq'. They provide mapping in a metadata file
|
|
124
174
|
for optimized installation.
|
|
125
175
|
"""
|
|
176
|
+
|
|
126
177
|
distro_name = DiscoveryMetadata.get_distribution_name_from_private_data(root_package)
|
|
127
178
|
return distro_name if distro_name else root_package
|
|
128
179
|
|
|
@@ -142,8 +193,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
142
193
|
|
|
143
194
|
try:
|
|
144
195
|
module = inspect.getmodule(config_type)
|
|
145
|
-
|
|
146
|
-
distro_name = DiscoveryMetadata.get_distribution_name(root_package)
|
|
196
|
+
distro_name = DiscoveryMetadata.get_distribution_name_from_config_type(config_type)
|
|
147
197
|
|
|
148
198
|
if not distro_name:
|
|
149
199
|
# raise an exception
|
|
@@ -187,12 +237,13 @@ class DiscoveryMetadata(BaseModel):
|
|
|
187
237
|
|
|
188
238
|
try:
|
|
189
239
|
module = inspect.getmodule(fn)
|
|
190
|
-
|
|
191
|
-
|
|
240
|
+
distro_name = DiscoveryMetadata.get_distribution_name_from_module(module)
|
|
241
|
+
|
|
192
242
|
try:
|
|
193
|
-
version = importlib.metadata.version(root_package) if root_package != "" else ""
|
|
243
|
+
# version = importlib.metadata.version(root_package) if root_package != "" else ""
|
|
244
|
+
version = importlib.metadata.version(distro_name) if distro_name != "" else ""
|
|
194
245
|
except importlib.metadata.PackageNotFoundError:
|
|
195
|
-
logger.warning("Package metadata not found for %s",
|
|
246
|
+
logger.warning("Package metadata not found for %s", distro_name)
|
|
196
247
|
version = ""
|
|
197
248
|
except Exception as e:
|
|
198
249
|
logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e, exc_info=True)
|
|
@@ -201,7 +252,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
201
252
|
if isinstance(wrapper_type, LLMFrameworkEnum):
|
|
202
253
|
wrapper_type = wrapper_type.value
|
|
203
254
|
|
|
204
|
-
return DiscoveryMetadata(package=
|
|
255
|
+
return DiscoveryMetadata(package=distro_name,
|
|
205
256
|
version=version,
|
|
206
257
|
component_type=component_type,
|
|
207
258
|
component_name=wrapper_type,
|
|
@@ -220,7 +271,6 @@ class DiscoveryMetadata(BaseModel):
|
|
|
220
271
|
"""
|
|
221
272
|
|
|
222
273
|
try:
|
|
223
|
-
package_name = DiscoveryMetadata.get_distribution_name(package_name)
|
|
224
274
|
try:
|
|
225
275
|
metadata = importlib.metadata.metadata(package_name)
|
|
226
276
|
description = metadata.get("Summary", "")
|
|
@@ -263,12 +313,11 @@ class DiscoveryMetadata(BaseModel):
|
|
|
263
313
|
|
|
264
314
|
try:
|
|
265
315
|
module = inspect.getmodule(config_type)
|
|
266
|
-
|
|
267
|
-
root_package = DiscoveryMetadata.get_distribution_name(root_package)
|
|
316
|
+
distro_name = DiscoveryMetadata.get_distribution_name_from_module(module)
|
|
268
317
|
try:
|
|
269
|
-
version = importlib.metadata.version(
|
|
318
|
+
version = importlib.metadata.version(distro_name) if distro_name != "" else ""
|
|
270
319
|
except importlib.metadata.PackageNotFoundError:
|
|
271
|
-
logger.warning("Package metadata not found for %s",
|
|
320
|
+
logger.warning("Package metadata not found for %s", distro_name)
|
|
272
321
|
version = ""
|
|
273
322
|
except Exception as e:
|
|
274
323
|
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e, exc_info=True)
|
|
@@ -279,7 +328,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
279
328
|
|
|
280
329
|
description = generate_config_type_docs(config_type=config_type)
|
|
281
330
|
|
|
282
|
-
return DiscoveryMetadata(package=
|
|
331
|
+
return DiscoveryMetadata(package=distro_name,
|
|
283
332
|
version=version,
|
|
284
333
|
component_type=component_type,
|
|
285
334
|
component_name=component_name,
|
|
@@ -15,12 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import logging
|
|
18
|
-
from io import StringIO
|
|
19
18
|
|
|
20
19
|
import click
|
|
21
20
|
from colorama import Fore
|
|
22
21
|
|
|
23
|
-
from aiq.builder.workflow_builder import WorkflowBuilder
|
|
24
22
|
from aiq.data_models.interactive import HumanPromptModelType
|
|
25
23
|
from aiq.data_models.interactive import HumanResponse
|
|
26
24
|
from aiq.data_models.interactive import HumanResponseText
|
|
@@ -61,27 +59,9 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
61
59
|
if (not self.front_end_config.input_query and not self.front_end_config.input_file):
|
|
62
60
|
raise click.UsageError("Must specify either --input_query or --input_file")
|
|
63
61
|
|
|
64
|
-
async def
|
|
65
|
-
|
|
66
|
-
# Must yield the workflow function otherwise it cleans up
|
|
67
|
-
async with WorkflowBuilder.from_config(config=self.full_config) as builder:
|
|
68
|
-
|
|
69
|
-
session_manager: AIQSessionManager = None
|
|
70
|
-
|
|
71
|
-
if logger.isEnabledFor(logging.INFO):
|
|
72
|
-
stream = StringIO()
|
|
73
|
-
|
|
74
|
-
self.full_config.print_summary(stream=stream)
|
|
75
|
-
|
|
76
|
-
click.echo(stream.getvalue())
|
|
77
|
-
|
|
78
|
-
workflow = builder.build()
|
|
79
|
-
session_manager = AIQSessionManager(workflow)
|
|
80
|
-
|
|
81
|
-
await self.run_workflow(session_manager)
|
|
82
|
-
|
|
83
|
-
async def run_workflow(self, session_manager: AIQSessionManager = None):
|
|
62
|
+
async def run_workflow(self, session_manager: AIQSessionManager):
|
|
84
63
|
|
|
64
|
+
assert session_manager is not None, "Session manager must be provided"
|
|
85
65
|
runner_outputs = None
|
|
86
66
|
|
|
87
67
|
if (self.front_end_config.input_query):
|
|
@@ -45,8 +45,10 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
|
|
|
45
45
|
|
|
46
46
|
click.echo(stream.getvalue())
|
|
47
47
|
|
|
48
|
-
|
|
48
|
+
workflow = builder.build()
|
|
49
|
+
session_manager = AIQSessionManager(workflow)
|
|
50
|
+
await self.run_workflow(session_manager)
|
|
49
51
|
|
|
50
52
|
@abstractmethod
|
|
51
|
-
async def run_workflow(self, session_manager: AIQSessionManager
|
|
53
|
+
async def run_workflow(self, session_manager: AIQSessionManager):
|
|
52
54
|
pass
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
17
|
+
|
|
16
18
|
from aiq.builder.builder import Builder
|
|
17
19
|
from aiq.cli.register_workflow import register_object_store
|
|
18
20
|
from aiq.data_models.object_store import KeyAlreadyExistsError
|
|
@@ -37,37 +39,37 @@ class InMemoryObjectStore(ObjectStore):
|
|
|
37
39
|
"""
|
|
38
40
|
|
|
39
41
|
def __init__(self) -> None:
|
|
42
|
+
self._lock = asyncio.Lock()
|
|
40
43
|
self._store: dict[str, ObjectStoreItem] = {}
|
|
41
44
|
|
|
42
45
|
@override
|
|
43
46
|
async def put_object(self, key: str, item: ObjectStoreItem) -> None:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
return
|
|
47
|
+
async with self._lock:
|
|
48
|
+
if key in self._store:
|
|
49
|
+
raise KeyAlreadyExistsError(key)
|
|
50
|
+
self._store[key] = item
|
|
49
51
|
|
|
50
52
|
@override
|
|
51
53
|
async def upsert_object(self, key: str, item: ObjectStoreItem) -> None:
|
|
52
|
-
self.
|
|
53
|
-
|
|
54
|
+
async with self._lock:
|
|
55
|
+
self._store[key] = item
|
|
54
56
|
|
|
55
57
|
@override
|
|
56
58
|
async def get_object(self, key: str) -> ObjectStoreItem:
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
59
|
+
async with self._lock:
|
|
60
|
+
value = self._store.get(key)
|
|
61
|
+
if value is None:
|
|
62
|
+
raise NoSuchKeyError(key)
|
|
63
|
+
return value
|
|
62
64
|
|
|
63
65
|
@override
|
|
64
66
|
async def delete_object(self, key: str) -> None:
|
|
65
|
-
|
|
67
|
+
try:
|
|
68
|
+
async with self._lock:
|
|
69
|
+
self._store.pop(key)
|
|
70
|
+
except KeyError:
|
|
66
71
|
raise NoSuchKeyError(key)
|
|
67
72
|
|
|
68
|
-
self._store.pop(key)
|
|
69
|
-
return
|
|
70
|
-
|
|
71
73
|
|
|
72
74
|
@register_object_store(config_type=InMemoryObjectStoreConfig)
|
|
73
75
|
async def in_memory_object_store(config: InMemoryObjectStoreConfig, builder: Builder):
|