aiqtoolkit 1.2.0rc2__py3-none-any.whl → 1.2.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of aiqtoolkit has been flagged as potentially problematic.

Files changed (37)
  1. aiq/agent/base.py +8 -7
  2. aiq/agent/react_agent/agent.py +2 -3
  3. aiq/agent/react_agent/register.py +1 -1
  4. aiq/agent/reasoning_agent/reasoning_agent.py +2 -1
  5. aiq/agent/tool_calling_agent/register.py +2 -1
  6. aiq/authentication/api_key/api_key_auth_provider.py +6 -2
  7. aiq/builder/function.py +21 -6
  8. aiq/builder/function_base.py +6 -2
  9. aiq/cli/commands/sizing/calc.py +6 -3
  10. aiq/cli/commands/start.py +0 -5
  11. aiq/cli/commands/uninstall.py +2 -4
  12. aiq/data_models/api_server.py +6 -12
  13. aiq/data_models/component_ref.py +1 -1
  14. aiq/data_models/discovery_metadata.py +62 -13
  15. aiq/front_ends/console/console_front_end_plugin.py +2 -22
  16. aiq/front_ends/simple_base/simple_front_end_plugin_base.py +4 -2
  17. aiq/object_store/in_memory_object_store.py +18 -16
  18. aiq/observability/exporter/processing_exporter.py +99 -46
  19. aiq/observability/exporter/span_exporter.py +1 -0
  20. aiq/observability/processor/batching_processor.py +52 -59
  21. aiq/observability/processor/callback_processor.py +42 -0
  22. aiq/observability/processor/processor.py +4 -1
  23. aiq/profiler/calc/calc_runner.py +5 -1
  24. aiq/profiler/calc/data_models.py +18 -6
  25. aiq/registry_handlers/package_utils.py +397 -28
  26. aiq/runtime/loader.py +23 -2
  27. aiq/tool/code_execution/README.md +0 -1
  28. aiq/tool/server_tools.py +1 -1
  29. aiq/utils/dump_distro_mapping.py +32 -0
  30. aiq/utils/type_converter.py +52 -10
  31. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/METADATA +1 -1
  32. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/RECORD +37 -35
  33. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/WHEEL +0 -0
  34. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/entry_points.txt +0 -0
  35. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  36. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/licenses/LICENSE.md +0 -0
  37. {aiqtoolkit-1.2.0rc2.dist-info → aiqtoolkit-1.2.0rc4.dist-info}/top_level.txt +0 -0
aiq/agent/base.py CHANGED
@@ -148,7 +148,7 @@ class BaseAgent(ABC):
         """
         last_exception = None
 
-        for attempt in range(max_retries + 1):
+        for attempt in range(1, max_retries + 1):
             try:
                 response = await tool.ainvoke(tool_input, config=config)
 
@@ -162,17 +162,18 @@ class BaseAgent(ABC):
 
             except Exception as e:
                 last_exception = e
-                logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
-                               AGENT_LOG_PREFIX,
-                               attempt + 1,
-                               max_retries + 1,
-                               tool.name,
-                               str(e))
 
                 # If this was the last attempt, don't sleep
                 if attempt == max_retries:
                     break
 
+                logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
+                               AGENT_LOG_PREFIX,
+                               attempt,
+                               max_retries,
+                               tool.name,
+                               str(e))
+
                 # Exponential backoff: 2^attempt seconds
                 sleep_time = 2**attempt
                 logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
aiq/agent/react_agent/agent.py CHANGED
@@ -193,12 +193,11 @@ class ReActAgentGraph(DualNodeAgent):
                                ex.observation,
                                output_message.content)
                 if attempt == self.parse_agent_response_max_retries:
-                    logger.error(
+                    logger.warning(
                         "%s Failed to parse agent output after %d attempts, consider enabling or "
                         "increasing parse_agent_response_max_retries",
                         AGENT_LOG_PREFIX,
-                        attempt,
-                        exc_info=True)
+                        attempt)
                     # the final answer goes in the "messages" state channel
                     combined_content = str(ex.observation) + '\n' + str(output_message.content)
                     output_message.content = combined_content
aiq/agent/react_agent/register.py CHANGED
@@ -18,7 +18,6 @@ import logging
 from pydantic import AliasChoices
 from pydantic import Field
 
-from aiq.agent.base import AGENT_LOG_PREFIX
 from aiq.builder.builder import Builder
 from aiq.builder.framework_enum import LLMFrameworkEnum
 from aiq.builder.function_info import FunctionInfo
@@ -79,6 +78,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
     from langchain_core.messages import trim_messages
     from langgraph.graph.graph import CompiledGraph
 
+    from aiq.agent.base import AGENT_LOG_PREFIX
     from aiq.agent.react_agent.agent import ReActAgentGraph
     from aiq.agent.react_agent.agent import ReActGraphState
     from aiq.agent.react_agent.agent import create_react_agent_prompt
aiq/agent/reasoning_agent/reasoning_agent.py CHANGED
@@ -19,7 +19,6 @@ from collections.abc import AsyncGenerator
 
 from pydantic import Field
 
-from aiq.agent.base import AGENT_LOG_PREFIX
 from aiq.builder.builder import Builder
 from aiq.builder.framework_enum import LLMFrameworkEnum
 from aiq.builder.function_info import FunctionInfo
@@ -86,6 +85,8 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
     from langchain_core.language_models import BaseChatModel
     from langchain_core.prompts import PromptTemplate
 
+    from aiq.agent.base import AGENT_LOG_PREFIX
+
     def remove_r1_think_tags(text: str):
         pattern = r'(<think>)?.*?</think>\s*(.*)'
 
aiq/agent/tool_calling_agent/register.py CHANGED
@@ -17,7 +17,6 @@ import logging
 
 from pydantic import Field
 
-from aiq.agent.base import AGENT_LOG_PREFIX
 from aiq.builder.builder import Builder
 from aiq.builder.framework_enum import LLMFrameworkEnum
 from aiq.builder.function_info import FunctionInfo
@@ -49,6 +48,8 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
     from langchain_core.messages.human import HumanMessage
     from langgraph.graph.graph import CompiledGraph
 
+    from aiq.agent.base import AGENT_LOG_PREFIX
+
     from .agent import ToolCallAgentGraph
     from .agent import ToolCallAgentGraphState
 
aiq/authentication/api_key/api_key_auth_provider.py CHANGED
@@ -28,9 +28,13 @@ logger = logging.getLogger(__name__)
 
 class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
 
-    def __init__(self, config: APIKeyAuthProviderConfig, config_name: str | None = None) -> None:
-        assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyConfig")
+    # fmt: off
+    def __init__(self,
+                 config: APIKeyAuthProviderConfig,
+                 config_name: str | None = None) -> None:  # pylint: disable=unused-argument
+        assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyAuthProviderConfig")
         super().__init__(config)
+    # fmt: on
 
     async def _construct_authentication_header(self) -> BearerTokenCred:
         """
aiq/builder/function.py CHANGED
@@ -76,11 +76,16 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
         -------
         _T
             The converted value.
+
+        Raises
+        ------
+        ValueError
+            If the value cannot be converted to the specified type (when `to_type` is specified).
         """
 
         return self._converter.convert(value, to_type=to_type)
 
-    def try_convert(self, value: typing.Any, to_type: type[_T]) -> _T:
+    def try_convert(self, value: typing.Any, to_type: type[_T]) -> _T | typing.Any:
         """
         Converts the given value to the specified type using graceful error handling.
         If conversion fails, returns the original value and continues processing.
@@ -94,7 +99,7 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
 
         Returns
         -------
-        _T
+        _T | typing.Any
             The converted value, or original value if conversion fails.
         """
         return self._converter.try_convert(value, to_type=to_type)
@@ -129,17 +134,22 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
         -------
         typing.Any
             The output of the function optionally converted to the specified type.
+
+        Raises
+        ------
+        ValueError
+            If the output of the function cannot be converted to the specified type.
         """
 
         with self._context.push_active_function(self.instance_name,
                                                 input_data=value) as manager:  # Set the current invocation context
             try:
-                converted_input: InputT = self._convert_input(value)  # type: ignore
+                converted_input: InputT = self._convert_input(value)
 
                 result = await self._ainvoke(converted_input)
 
                 if to_type is not None and not isinstance(result, to_type):
-                    result = self._converter.try_convert(result, to_type=to_type)
+                    result = self.convert(result, to_type)
 
                 manager.set_output(result)
 
@@ -215,18 +225,23 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
         ------
         typing.Any
             The output of the function optionally converted to the specified type.
+
+        Raises
+        ------
+        ValueError
+            If the output of the function cannot be converted to the specified type (when `to_type` is specified).
         """
 
         with self._context.push_active_function(self.instance_name, input_data=value) as manager:
             try:
-                converted_input: InputT = self._convert_input(value)  # type: ignore
+                converted_input: InputT = self._convert_input(value)
 
                 # Collect streaming outputs to capture the final result
                 final_output: list[typing.Any] = []
 
                 async for data in self._astream(converted_input):
                     if to_type is not None and not isinstance(data, to_type):
-                        converted_data = self._converter.try_convert(data, to_type=to_type)
+                        converted_data = self.convert(data, to_type=to_type)
                         final_output.append(converted_data)
                         yield converted_data
                     else:
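As a side note, a tiny sketch (plain callables standing in for the toolkit's converter; names are hypothetical) of the convert/try_convert contrast these hunks lean on: switching ainvoke/astream from try_convert to convert means output-conversion failures now raise instead of silently passing the original value through.

def convert(value, to_type):
    """Strict: raise ValueError when the value cannot be converted."""
    try:
        return to_type(value)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Cannot convert {value!r} to {to_type.__name__}") from e


def try_convert(value, to_type):
    """Graceful: hand back the original value when conversion fails."""
    try:
        return to_type(value)
    except (TypeError, ValueError):
        return value


print(convert("42", int))        # 42
print(try_convert("oops", int))  # 'oops' (unchanged)
# convert("oops", int) raises ValueError, which the updated paths propagate to the caller.
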
aiq/builder/function_base.py CHANGED
@@ -350,7 +350,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
         # output because the ABC has it.
         return True
 
-    def _convert_input(self, value: typing.Any):
+    def _convert_input(self, value: typing.Any) -> InputT:
         if (isinstance(value, self.input_class)):
             return value
 
@@ -373,4 +373,8 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
             return value
 
         # Fallback to the converter
-        return self._converter.try_convert(value, to_type=self.input_class)
+        try:
+            return self._converter.convert(value, to_type=self.input_class)
+        except ValueError as e:
+            # Input parsing should yield a TypeError instead of a ValueError
+            raise TypeError from e
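A small illustrative sketch (function name is hypothetical) of the exception-translation idiom used in the new fallback, where a converter's ValueError surfaces to callers as a TypeError while the original is kept as __cause__:

def parse_input(raw: str) -> int:
    try:
        return int(raw)  # stand-in for self._converter.convert(...)
    except ValueError as e:
        # Input parsing should surface as a TypeError; chain the original for debugging.
        raise TypeError(f"Could not parse input {raw!r}") from e


try:
    parse_input("not-a-number")
except TypeError as e:
    print(type(e.__cause__).__name__)  # ValueError
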
aiq/cli/commands/sizing/calc.py CHANGED
@@ -274,9 +274,12 @@ def calc_command(ctx,
 
     click.echo(tabulate(table, headers=headers, tablefmt="github"))
 
-    # Display slope-based GPU estimates at the end
-    click.echo("")  # Add blank line for separation
-    click.echo(click.style("=== GPU ESTIMATES ===", fg="bright_blue", bold=True))
+    # Display slope-based GPU estimates if they are available
+    if results.gpu_estimates.gpu_estimate_by_llm_latency is not None or \
+            results.gpu_estimates.gpu_estimate_by_wf_runtime is not None:
+        click.echo("")
+        click.echo(click.style("=== GPU ESTIMATES ===", fg="bright_blue", bold=True))
+
     if results.gpu_estimates.gpu_estimate_by_wf_runtime is not None:
         click.echo(
             click.style(
aiq/cli/commands/start.py CHANGED
@@ -190,11 +190,6 @@ class StartCommandGroup(click.Group):
         # Override default front end config with values from the config file for serverless execution modes.
         # Check that we have the right kind of front end
         if (not isinstance(config.general.front_end, front_end.config_type)):
-            logger.warning(
-                "The front end type in the config file (%s) does not match the command name (%s). "
-                "Overwriting the config file front end.",
-                config.general.front_end.type,
-                cmd_name)
 
             # Set the front end config
             config.general.front_end = front_end.config_type()
aiq/cli/commands/uninstall.py CHANGED
@@ -53,13 +53,11 @@ async def uninstall_packages(packages: list[dict[str, str]]) -> None:
         await stack.enter_async_context(registry_handler.remove(packages=package_name_list))
 
 
-@click.group(name=__name__,
-             invoke_without_command=True,
-             help=("Uninstall an AIQ Toolkit plugin packages from the local environment."))
+@click.group(name=__name__, invoke_without_command=True, help=("Uninstall plugin packages from the local environment."))
 @click.argument("packages", type=str)
 def uninstall_command(packages: str) -> None:
     """
-    Uninstall AIQ Toolkit plugin packages from the local environment.
+    Uninstall plugin packages from the local environment.
     """
 
     packages = packages.split()
aiq/data_models/api_server.py CHANGED
@@ -121,28 +121,22 @@ class AIQChatRequest(BaseModel):
     # Optional fields (OpenAI Chat Completions API compatible)
     model: str | None = Field(default=None, description="name of the model to use")
     frequency_penalty: float | None = Field(default=0.0,
-                                            ge=-2.0,
-                                            le=2.0,
                                             description="Penalty for new tokens based on frequency in text")
     logit_bias: dict[str, float] | None = Field(default=None,
                                                 description="Modify likelihood of specified tokens appearing")
     logprobs: bool | None = Field(default=None, description="Whether to return log probabilities")
-    top_logprobs: int | None = Field(default=None, ge=0, le=20, description="Number of most likely tokens to return")
-    max_tokens: int | None = Field(default=None, ge=1, description="Maximum number of tokens to generate")
-    n: int | None = Field(default=1, ge=1, le=128, description="Number of chat completion choices to generate")
-    presence_penalty: float | None = Field(default=0.0,
-                                           ge=-2.0,
-                                           le=2.0,
-                                           description="Penalty for new tokens based on presence in text")
+    top_logprobs: int | None = Field(default=None, description="Number of most likely tokens to return")
+    max_tokens: int | None = Field(default=None, description="Maximum number of tokens to generate")
+    n: int | None = Field(default=1, description="Number of chat completion choices to generate")
+    presence_penalty: float | None = Field(default=0.0, description="Penalty for new tokens based on presence in text")
     response_format: dict[str, typing.Any] | None = Field(default=None, description="Response format specification")
     seed: int | None = Field(default=None, description="Random seed for deterministic sampling")
     service_tier: typing.Literal["auto", "default"] | None = Field(default=None,
                                                                    description="Service tier for the request")
-    stop: str | list[str] | None = Field(default=None, description="Up to 4 sequences where API will stop generating")
     stream: bool | None = Field(default=False, description="Whether to stream partial message deltas")
     stream_options: dict[str, typing.Any] | None = Field(default=None, description="Options for streaming")
-    temperature: float | None = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature between 0 and 2")
-    top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling parameter")
+    temperature: float | None = Field(default=1.0, description="Sampling temperature between 0 and 2")
+    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
     tools: list[dict[str, typing.Any]] | None = Field(default=None, description="List of tools the model may call")
     tool_choice: str | dict[str, typing.Any] | None = Field(default=None, description="Controls which tool is called")
     parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
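For reference, a minimal pydantic sketch (model names are hypothetical) of what the removed ge/le bounds did: with them, out-of-range values fail validation; without them, as in the hunk above, any numeric value is accepted.

from pydantic import BaseModel, Field, ValidationError


class BoundedRequest(BaseModel):
    temperature: float | None = Field(default=1.0, ge=0.0, le=2.0)


class UnboundedRequest(BaseModel):
    temperature: float | None = Field(default=1.0)


try:
    BoundedRequest(temperature=5.0)
except ValidationError as e:
    print(e.errors()[0]["type"])  # e.g. 'less_than_equal'

print(UnboundedRequest(temperature=5.0).temperature)  # 5.0, accepted without complaint
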
aiq/data_models/component_ref.py CHANGED
@@ -130,7 +130,7 @@ class ObjectStoreRef(ComponentRef):
     """
 
     @property
-    @typing.override
+    @override
    def component_group(self):
        return ComponentGroup.OBJECT_STORES
 
aiq/data_models/discovery_metadata.py CHANGED
@@ -21,6 +21,7 @@ import typing
 from enum import Enum
 from functools import lru_cache
 from pathlib import Path
+from types import ModuleType
 from typing import TYPE_CHECKING
 
 from pydantic import BaseModel
@@ -115,6 +116,55 @@ class DiscoveryMetadata(BaseModel):
                 return data.get(root_package, None)
         return None
 
+    @staticmethod
+    @lru_cache
+    def get_distribution_name_from_module(module: ModuleType | None) -> str:
+        """Get the distribution name from the config type using the mapping of module names to distro names.
+
+        Args:
+            module (ModuleType): A registered component's module.
+
+        Returns:
+            str: The distribution name of the AIQ Toolkit component.
+        """
+        from aiq.runtime.loader import get_all_aiq_entrypoints_distro_mapping
+
+        if module is None:
+            return "aiqtoolkit"
+
+        # Get the mapping of module names to distro names
+        mapping = get_all_aiq_entrypoints_distro_mapping()
+        module_package = module.__package__
+
+        if module_package is None:
+            return "aiqtoolkit"
+
+        # Traverse the module package parts in reverse order to find the distro name
+        # This is because the module package is the root package for the AIQ Toolkit component
+        # and the distro name is the name of the package that contains the component
+        module_package_parts = module_package.split(".")
+        for part_idx in range(len(module_package_parts), 0, -1):
+            candidate_module_name = ".".join(module_package_parts[0:part_idx])
+            candidate_distro_name = mapping.get(candidate_module_name, None)
+            if candidate_distro_name is not None:
+                return candidate_distro_name
+
+        return "aiqtoolkit"
+
+    @staticmethod
+    @lru_cache
+    def get_distribution_name_from_config_type(config_type: type["TypedBaseModelT"]) -> str:
+        """Get the distribution name from the config type using the mapping of module names to distro names.
+
+        Args:
+            config_type (type[TypedBaseModelT]): A registered component's configuration object.
+
+        Returns:
+            str: The distribution name of the AIQ Toolkit component.
+        """
+        module = inspect.getmodule(config_type)
+        return DiscoveryMetadata.get_distribution_name_from_module(module)
+
     @staticmethod
     @lru_cache
     def get_distribution_name(root_package: str) -> str:
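To illustrate the lookup the new get_distribution_name_from_module performs, a self-contained sketch (the mapping contents and helper name are hypothetical) of matching the longest dotted-module prefix against an entry-point-to-distribution mapping:

mapping = {
    "aiq": "aiqtoolkit",                       # hypothetical entries
    "aiq_vendor_plugin": "aiqtoolkit-vendor",
}


def distro_for(module_package: str, default: str = "aiqtoolkit") -> str:
    """Walk the dotted path from most to least specific and return the first mapped distro."""
    parts = module_package.split(".")
    for idx in range(len(parts), 0, -1):
        candidate = ".".join(parts[:idx])
        if candidate in mapping:
            return mapping[candidate]
    return default


print(distro_for("aiq_vendor_plugin.tools.search"))  # aiqtoolkit-vendor
print(distro_for("unmapped_package.module"))         # aiqtoolkit
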
@@ -123,6 +173,7 @@ class DiscoveryMetadata(BaseModel):
         root package name 'aiq'. They provide mapping in a metadata file
         for optimized installation.
         """
+
         distro_name = DiscoveryMetadata.get_distribution_name_from_private_data(root_package)
         return distro_name if distro_name else root_package
 
@@ -142,8 +193,7 @@ class DiscoveryMetadata(BaseModel):
 
         try:
             module = inspect.getmodule(config_type)
-            root_package: str = module.__package__.split(".")[0]
-            distro_name = DiscoveryMetadata.get_distribution_name(root_package)
+            distro_name = DiscoveryMetadata.get_distribution_name_from_config_type(config_type)
 
             if not distro_name:
                 # raise an exception
@@ -187,12 +237,13 @@ class DiscoveryMetadata(BaseModel):
 
         try:
             module = inspect.getmodule(fn)
-            root_package: str = module.__package__.split(".")[0]
-            root_package = DiscoveryMetadata.get_distribution_name(root_package)
+            distro_name = DiscoveryMetadata.get_distribution_name_from_module(module)
+
             try:
-                version = importlib.metadata.version(root_package) if root_package != "" else ""
+                # version = importlib.metadata.version(root_package) if root_package != "" else ""
+                version = importlib.metadata.version(distro_name) if distro_name != "" else ""
             except importlib.metadata.PackageNotFoundError:
-                logger.warning("Package metadata not found for %s", root_package)
+                logger.warning("Package metadata not found for %s", distro_name)
                 version = ""
         except Exception as e:
             logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e, exc_info=True)
@@ -201,7 +252,7 @@ class DiscoveryMetadata(BaseModel):
         if isinstance(wrapper_type, LLMFrameworkEnum):
             wrapper_type = wrapper_type.value
 
-        return DiscoveryMetadata(package=root_package,
+        return DiscoveryMetadata(package=distro_name,
                                  version=version,
                                  component_type=component_type,
                                  component_name=wrapper_type,
@@ -220,7 +271,6 @@ class DiscoveryMetadata(BaseModel):
         """
 
         try:
-            package_name = DiscoveryMetadata.get_distribution_name(package_name)
             try:
                 metadata = importlib.metadata.metadata(package_name)
                 description = metadata.get("Summary", "")
@@ -263,12 +313,11 @@ class DiscoveryMetadata(BaseModel):
 
         try:
             module = inspect.getmodule(config_type)
-            root_package: str = module.__package__.split(".")[0]
-            root_package = DiscoveryMetadata.get_distribution_name(root_package)
+            distro_name = DiscoveryMetadata.get_distribution_name_from_module(module)
             try:
-                version = importlib.metadata.version(root_package) if root_package != "" else ""
+                version = importlib.metadata.version(distro_name) if distro_name != "" else ""
             except importlib.metadata.PackageNotFoundError:
-                logger.warning("Package metadata not found for %s", root_package)
+                logger.warning("Package metadata not found for %s", distro_name)
                 version = ""
         except Exception as e:
             logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e, exc_info=True)
@@ -279,7 +328,7 @@ class DiscoveryMetadata(BaseModel):
 
         description = generate_config_type_docs(config_type=config_type)
 
-        return DiscoveryMetadata(package=root_package,
+        return DiscoveryMetadata(package=distro_name,
                                  version=version,
                                  component_type=component_type,
                                  component_name=component_name,
aiq/front_ends/console/console_front_end_plugin.py CHANGED
@@ -15,12 +15,10 @@
 
 import asyncio
 import logging
-from io import StringIO
 
 import click
 from colorama import Fore
 
-from aiq.builder.workflow_builder import WorkflowBuilder
 from aiq.data_models.interactive import HumanPromptModelType
 from aiq.data_models.interactive import HumanResponse
 from aiq.data_models.interactive import HumanResponseText
@@ -61,27 +59,9 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
         if (not self.front_end_config.input_query and not self.front_end_config.input_file):
             raise click.UsageError("Must specify either --input_query or --input_file")
 
-    async def run(self):
-
-        # Must yield the workflow function otherwise it cleans up
-        async with WorkflowBuilder.from_config(config=self.full_config) as builder:
-
-            session_manager: AIQSessionManager = None
-
-            if logger.isEnabledFor(logging.INFO):
-                stream = StringIO()
-
-                self.full_config.print_summary(stream=stream)
-
-                click.echo(stream.getvalue())
-
-            workflow = builder.build()
-            session_manager = AIQSessionManager(workflow)
-
-            await self.run_workflow(session_manager)
-
-    async def run_workflow(self, session_manager: AIQSessionManager = None):
+    async def run_workflow(self, session_manager: AIQSessionManager):
 
+        assert session_manager is not None, "Session manager must be provided"
         runner_outputs = None
 
         if (self.front_end_config.input_query):
aiq/front_ends/simple_base/simple_front_end_plugin_base.py CHANGED
@@ -45,8 +45,10 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
 
                 click.echo(stream.getvalue())
 
-            await self.run_workflow(builder.build())
+            workflow = builder.build()
+            session_manager = AIQSessionManager(workflow)
+            await self.run_workflow(session_manager)
 
     @abstractmethod
-    async def run_workflow(self, session_manager: AIQSessionManager = None):
+    async def run_workflow(self, session_manager: AIQSessionManager):
         pass
aiq/object_store/in_memory_object_store.py CHANGED
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
+
 from aiq.builder.builder import Builder
 from aiq.cli.register_workflow import register_object_store
 from aiq.data_models.object_store import KeyAlreadyExistsError
@@ -37,37 +39,37 @@ class InMemoryObjectStore(ObjectStore):
     """
 
     def __init__(self) -> None:
+        self._lock = asyncio.Lock()
         self._store: dict[str, ObjectStoreItem] = {}
 
     @override
     async def put_object(self, key: str, item: ObjectStoreItem) -> None:
-        if key in self._store:
-            raise KeyAlreadyExistsError(key)
-
-        self._store[key] = item
-        return
+        async with self._lock:
+            if key in self._store:
+                raise KeyAlreadyExistsError(key)
+            self._store[key] = item
 
     @override
     async def upsert_object(self, key: str, item: ObjectStoreItem) -> None:
-        self._store[key] = item
-        return
+        async with self._lock:
+            self._store[key] = item
 
     @override
    async def get_object(self, key: str) -> ObjectStoreItem:
-
-        if key not in self._store:
-            raise NoSuchKeyError(key)
-
-        return self._store[key]
+        async with self._lock:
+            value = self._store.get(key)
+            if value is None:
+                raise NoSuchKeyError(key)
+            return value
 
     @override
     async def delete_object(self, key: str) -> None:
-        if key not in self._store:
+        try:
+            async with self._lock:
+                self._store.pop(key)
+        except KeyError:
             raise NoSuchKeyError(key)
 
-        self._store.pop(key)
-        return
-
 
 @register_object_store(config_type=InMemoryObjectStoreConfig)
 async def in_memory_object_store(config: InMemoryObjectStoreConfig, builder: Builder):
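Finally, a runnable sketch (class and exception names are hypothetical, not the toolkit's) of the lock-guarded check-then-insert pattern the in-memory store adopts above, where concurrent puts of the same key yield exactly one success:

import asyncio


class KeyExists(Exception):
    pass


class TinyStore:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._store: dict[str, str] = {}

    async def put(self, key: str, value: str) -> None:
        async with self._lock:  # keep the membership check and the insert atomic
            if key in self._store:
                raise KeyExists(key)
            self._store[key] = value


async def main() -> None:
    store = TinyStore()
    results = await asyncio.gather(store.put("k", "a"), store.put("k", "b"), return_exceptions=True)
    print(sorted(type(r).__name__ for r in results))  # ['KeyExists', 'NoneType']


asyncio.run(main())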