alita-sdk 0.3.365__py3-none-any.whl → 0.3.462__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (118)
  1. alita_sdk/cli/__init__.py +10 -0
  2. alita_sdk/cli/__main__.py +17 -0
  3. alita_sdk/cli/agent_executor.py +144 -0
  4. alita_sdk/cli/agent_loader.py +197 -0
  5. alita_sdk/cli/agent_ui.py +166 -0
  6. alita_sdk/cli/agents.py +1069 -0
  7. alita_sdk/cli/callbacks.py +576 -0
  8. alita_sdk/cli/cli.py +159 -0
  9. alita_sdk/cli/config.py +153 -0
  10. alita_sdk/cli/formatting.py +182 -0
  11. alita_sdk/cli/mcp_loader.py +315 -0
  12. alita_sdk/cli/toolkit.py +330 -0
  13. alita_sdk/cli/toolkit_loader.py +55 -0
  14. alita_sdk/cli/tools/__init__.py +9 -0
  15. alita_sdk/cli/tools/filesystem.py +905 -0
  16. alita_sdk/configurations/bitbucket.py +95 -0
  17. alita_sdk/configurations/confluence.py +96 -1
  18. alita_sdk/configurations/gitlab.py +79 -0
  19. alita_sdk/configurations/jira.py +103 -0
  20. alita_sdk/configurations/testrail.py +88 -0
  21. alita_sdk/configurations/xray.py +93 -0
  22. alita_sdk/configurations/zephyr_enterprise.py +93 -0
  23. alita_sdk/configurations/zephyr_essential.py +75 -0
  24. alita_sdk/runtime/clients/artifact.py +1 -1
  25. alita_sdk/runtime/clients/client.py +47 -10
  26. alita_sdk/runtime/clients/mcp_discovery.py +342 -0
  27. alita_sdk/runtime/clients/mcp_manager.py +262 -0
  28. alita_sdk/runtime/clients/sandbox_client.py +373 -0
  29. alita_sdk/runtime/langchain/assistant.py +70 -41
  30. alita_sdk/runtime/langchain/constants.py +6 -1
  31. alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
  32. alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -1
  33. alita_sdk/runtime/langchain/document_loaders/constants.py +73 -100
  34. alita_sdk/runtime/langchain/langraph_agent.py +164 -38
  35. alita_sdk/runtime/langchain/utils.py +43 -7
  36. alita_sdk/runtime/models/mcp_models.py +61 -0
  37. alita_sdk/runtime/toolkits/__init__.py +24 -0
  38. alita_sdk/runtime/toolkits/application.py +8 -1
  39. alita_sdk/runtime/toolkits/artifact.py +5 -6
  40. alita_sdk/runtime/toolkits/mcp.py +895 -0
  41. alita_sdk/runtime/toolkits/tools.py +140 -50
  42. alita_sdk/runtime/tools/__init__.py +7 -2
  43. alita_sdk/runtime/tools/application.py +7 -0
  44. alita_sdk/runtime/tools/function.py +94 -5
  45. alita_sdk/runtime/tools/graph.py +10 -4
  46. alita_sdk/runtime/tools/image_generation.py +104 -8
  47. alita_sdk/runtime/tools/llm.py +204 -114
  48. alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
  49. alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
  50. alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
  51. alita_sdk/runtime/tools/sandbox.py +180 -79
  52. alita_sdk/runtime/tools/vectorstore.py +22 -21
  53. alita_sdk/runtime/tools/vectorstore_base.py +79 -26
  54. alita_sdk/runtime/utils/mcp_oauth.py +164 -0
  55. alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
  56. alita_sdk/runtime/utils/streamlit.py +34 -3
  57. alita_sdk/runtime/utils/toolkit_utils.py +14 -4
  58. alita_sdk/runtime/utils/utils.py +1 -0
  59. alita_sdk/tools/__init__.py +48 -31
  60. alita_sdk/tools/ado/repos/__init__.py +1 -0
  61. alita_sdk/tools/ado/test_plan/__init__.py +1 -1
  62. alita_sdk/tools/ado/wiki/__init__.py +1 -5
  63. alita_sdk/tools/ado/work_item/__init__.py +1 -5
  64. alita_sdk/tools/ado/work_item/ado_wrapper.py +17 -8
  65. alita_sdk/tools/base_indexer_toolkit.py +194 -112
  66. alita_sdk/tools/bitbucket/__init__.py +1 -0
  67. alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
  68. alita_sdk/tools/code/sonar/__init__.py +1 -1
  69. alita_sdk/tools/code_indexer_toolkit.py +15 -5
  70. alita_sdk/tools/confluence/__init__.py +2 -2
  71. alita_sdk/tools/confluence/api_wrapper.py +110 -63
  72. alita_sdk/tools/confluence/loader.py +10 -0
  73. alita_sdk/tools/elitea_base.py +22 -22
  74. alita_sdk/tools/github/__init__.py +2 -2
  75. alita_sdk/tools/gitlab/__init__.py +2 -1
  76. alita_sdk/tools/gitlab/api_wrapper.py +11 -7
  77. alita_sdk/tools/gitlab_org/__init__.py +1 -2
  78. alita_sdk/tools/google_places/__init__.py +2 -1
  79. alita_sdk/tools/jira/__init__.py +1 -0
  80. alita_sdk/tools/jira/api_wrapper.py +1 -1
  81. alita_sdk/tools/memory/__init__.py +1 -1
  82. alita_sdk/tools/non_code_indexer_toolkit.py +2 -2
  83. alita_sdk/tools/openapi/__init__.py +10 -1
  84. alita_sdk/tools/pandas/__init__.py +1 -1
  85. alita_sdk/tools/postman/__init__.py +2 -1
  86. alita_sdk/tools/postman/api_wrapper.py +18 -8
  87. alita_sdk/tools/postman/postman_analysis.py +8 -1
  88. alita_sdk/tools/pptx/__init__.py +2 -2
  89. alita_sdk/tools/qtest/__init__.py +3 -3
  90. alita_sdk/tools/qtest/api_wrapper.py +1708 -76
  91. alita_sdk/tools/rally/__init__.py +1 -2
  92. alita_sdk/tools/report_portal/__init__.py +1 -0
  93. alita_sdk/tools/salesforce/__init__.py +1 -0
  94. alita_sdk/tools/servicenow/__init__.py +2 -3
  95. alita_sdk/tools/sharepoint/__init__.py +1 -0
  96. alita_sdk/tools/sharepoint/api_wrapper.py +125 -34
  97. alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
  98. alita_sdk/tools/sharepoint/utils.py +8 -2
  99. alita_sdk/tools/slack/__init__.py +1 -0
  100. alita_sdk/tools/sql/__init__.py +2 -1
  101. alita_sdk/tools/sql/api_wrapper.py +71 -23
  102. alita_sdk/tools/testio/__init__.py +1 -0
  103. alita_sdk/tools/testrail/__init__.py +1 -3
  104. alita_sdk/tools/utils/__init__.py +17 -0
  105. alita_sdk/tools/utils/content_parser.py +35 -24
  106. alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +67 -21
  107. alita_sdk/tools/xray/__init__.py +2 -1
  108. alita_sdk/tools/zephyr/__init__.py +2 -1
  109. alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
  110. alita_sdk/tools/zephyr_essential/__init__.py +1 -0
  111. alita_sdk/tools/zephyr_scale/__init__.py +1 -0
  112. alita_sdk/tools/zephyr_squad/__init__.py +1 -0
  113. {alita_sdk-0.3.365.dist-info → alita_sdk-0.3.462.dist-info}/METADATA +8 -2
  114. {alita_sdk-0.3.365.dist-info → alita_sdk-0.3.462.dist-info}/RECORD +118 -93
  115. alita_sdk-0.3.462.dist-info/entry_points.txt +2 -0
  116. {alita_sdk-0.3.365.dist-info → alita_sdk-0.3.462.dist-info}/WHEEL +0 -0
  117. {alita_sdk-0.3.365.dist-info → alita_sdk-0.3.462.dist-info}/licenses/LICENSE +0 -0
  118. {alita_sdk-0.3.365.dist-info → alita_sdk-0.3.462.dist-info}/top_level.txt +0 -0

alita_sdk/runtime/toolkits/tools.py

@@ -1,4 +1,5 @@
 import logging
+from typing import Optional
 
 from langchain_core.tools import ToolException
 from langgraph.store.base import BaseStore
@@ -11,10 +12,14 @@ from .datasource import DatasourcesToolkit
 from .prompt import PromptToolkit
 from .subgraph import SubgraphToolkit
 from .vectorstore import VectorStoreToolkit
+from .mcp import McpToolkit
 from ..tools.mcp_server_tool import McpServerTool
+from ..tools.sandbox import SandboxToolkit
+from ..tools.image_generation import ImageGenerationToolkit
 # Import community tools
 from ...community import get_toolkits as community_toolkits, get_tools as community_tools
 from ...tools.memory import MemoryToolkit
+from ..utils.mcp_oauth import canonical_resource, McpAuthorizationRequired
 from ...tools.utils import TOOLKIT_SPLITTER
 
 logger = logging.getLogger(__name__)
@@ -24,64 +29,139 @@ def get_toolkits():
     core_toolkits = [
         ArtifactToolkit.toolkit_config_schema(),
         MemoryToolkit.toolkit_config_schema(),
-        VectorStoreToolkit.toolkit_config_schema()
+        VectorStoreToolkit.toolkit_config_schema(),
+        SandboxToolkit.toolkit_config_schema(),
+        ImageGenerationToolkit.toolkit_config_schema(),
+        McpToolkit.toolkit_config_schema()
     ]
 
     return core_toolkits + community_toolkits() + alita_toolkits()
 
 
-def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = None) -> list:
+def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = None, debug_mode: Optional[bool] = False, mcp_tokens: Optional[dict] = None) -> list:
     prompts = []
     tools = []
 
     for tool in tools_list:
-        if tool['type'] == 'datasource':
-            tools.extend(DatasourcesToolkit.get_toolkit(
-                alita_client,
-                datasource_ids=[int(tool['settings']['datasource_id'])],
-                selected_tools=tool['settings']['selected_tools'],
-                toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
-            ).get_tools())
-        elif tool['type'] == 'application' and tool.get('agent_type', '') != 'pipeline' :
-            tools.extend(ApplicationToolkit.get_toolkit(
-                alita_client,
-                application_id=int(tool['settings']['application_id']),
-                application_version_id=int(tool['settings']['application_version_id']),
-                selected_tools=[]
-            ).get_tools())
-        elif tool['type'] == 'application' and tool.get('agent_type', '') == 'pipeline':
-            # static get_toolkit returns a list of CompiledStateGraph stubs
-            tools.extend(SubgraphToolkit.get_toolkit(
-                alita_client,
-                application_id=int(tool['settings']['application_id']),
-                application_version_id=int(tool['settings']['application_version_id']),
-                app_api_key=alita_client.auth_token,
-                selected_tools=[],
-                llm=llm
-            ))
-        elif tool['type'] == 'memory':
-            tools += MemoryToolkit.get_toolkit(
-                namespace=tool['settings'].get('namespace', str(tool['id'])),
-                pgvector_configuration=tool['settings'].get('pgvector_configuration', {}),
-                store=memory_store,
-            ).get_tools()
-        elif tool['type'] == 'artifact':
-            tools.extend(ArtifactToolkit.get_toolkit(
-                client=alita_client,
-                bucket=tool['settings']['bucket'],
-                toolkit_name=tool.get('toolkit_name', ''),
-                selected_tools=tool['settings'].get('selected_tools', []),
-                llm=llm,
-                # indexer settings
-                pgvector_configuration=tool['settings'].get('pgvector_configuration', {}),
-                embedding_model=tool['settings'].get('embedding_model'),
-                collection_name=f"{tool.get('toolkit_name')}",
-            ).get_tools())
-        elif tool['type'] == 'vectorstore':
-            tools.extend(VectorStoreToolkit.get_toolkit(
-                llm=llm,
-                toolkit_name=tool.get('toolkit_name', ''),
-                **tool['settings']).get_tools())
+        try:
+            if tool['type'] == 'datasource':
+                tools.extend(DatasourcesToolkit.get_toolkit(
+                    alita_client,
+                    datasource_ids=[int(tool['settings']['datasource_id'])],
+                    selected_tools=tool['settings']['selected_tools'],
+                    toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
+                ).get_tools())
+            elif tool['type'] == 'application':
+                tools.extend(ApplicationToolkit.get_toolkit(
+                    alita_client,
+                    application_id=int(tool['settings']['application_id']),
+                    application_version_id=int(tool['settings']['application_version_id']),
+                    selected_tools=[]
+                ).get_tools())
+                # backward compatibility for pipeline application type as subgraph node
+                if tool.get('agent_type', '') == 'pipeline':
+                    # static get_toolkit returns a list of CompiledStateGraph stubs
+                    tools.extend(SubgraphToolkit.get_toolkit(
+                        alita_client,
+                        application_id=int(tool['settings']['application_id']),
+                        application_version_id=int(tool['settings']['application_version_id']),
+                        app_api_key=alita_client.auth_token,
+                        selected_tools=[],
+                        llm=llm
+                    ))
+            elif tool['type'] == 'memory':
+                tools += MemoryToolkit.get_toolkit(
+                    namespace=tool['settings'].get('namespace', str(tool['id'])),
+                    pgvector_configuration=tool['settings'].get('pgvector_configuration', {}),
+                    store=memory_store,
+                ).get_tools()
+            # TODO: update configuration of internal tools
+            elif tool['type'] == 'internal_tool':
+                if tool['name'] == 'pyodide':
+                    tools += SandboxToolkit.get_toolkit(
+                        stateful=False,
+                        allow_net=True,
+                        alita_client=alita_client,
+                    ).get_tools()
+                elif tool['name'] == 'image_generation':
+                    if alita_client and alita_client.model_image_generation:
+                        tools += ImageGenerationToolkit.get_toolkit(
+                            client=alita_client,
+                        ).get_tools()
+                    else:
+                        logger.warning("Image generation internal tool requested "
+                                       "but no image generation model configured")
+            elif tool['type'] == 'artifact':
+                tools.extend(ArtifactToolkit.get_toolkit(
+                    client=alita_client,
+                    bucket=tool['settings']['bucket'],
+                    toolkit_name=tool.get('toolkit_name', ''),
+                    selected_tools=tool['settings'].get('selected_tools', []),
+                    llm=llm,
+                    # indexer settings
+                    pgvector_configuration=tool['settings'].get('pgvector_configuration', {}),
+                    embedding_model=tool['settings'].get('embedding_model'),
+                    collection_name=f"{tool.get('toolkit_name')}",
+                    collection_schema = str(tool['id'])
+                ).get_tools())
+            elif tool['type'] == 'vectorstore':
+                tools.extend(VectorStoreToolkit.get_toolkit(
+                    llm=llm,
+                    toolkit_name=tool.get('toolkit_name', ''),
+                    **tool['settings']).get_tools())
+            elif tool['type'] == 'mcp':
+                # remote mcp tool initialization with token injection
+                settings = dict(tool['settings'])
+                url = settings.get('url')
+                headers = settings.get('headers')
+                token_data = None
+                session_id = None
+                if mcp_tokens and url:
+                    canonical_url = canonical_resource(url)
+                    logger.info(f"[MCP Auth] Looking for token for URL: {url}")
+                    logger.info(f"[MCP Auth] Canonical URL: {canonical_url}")
+                    logger.info(f"[MCP Auth] Available tokens: {list(mcp_tokens.keys())}")
+                    token_data = mcp_tokens.get(canonical_url)
+                    if token_data:
+                        logger.info(f"[MCP Auth] Found token data for {canonical_url}")
+                        # Handle both old format (string) and new format (dict with access_token and session_id)
+                        if isinstance(token_data, dict):
+                            access_token = token_data.get('access_token')
+                            session_id = token_data.get('session_id')
+                            logger.info(f"[MCP Auth] Token data: access_token={'present' if access_token else 'missing'}, session_id={session_id or 'none'}")
+                        else:
+                            # Backward compatibility: treat as plain token string
+                            access_token = token_data
+                            logger.info(f"[MCP Auth] Using legacy token format (string)")
+                    else:
+                        access_token = None
+                        logger.warning(f"[MCP Auth] No token found for {canonical_url}")
+                else:
+                    access_token = None
+
+                if access_token:
+                    merged_headers = dict(headers) if headers else {}
+                    merged_headers.setdefault('Authorization', f'Bearer {access_token}')
+                    settings['headers'] = merged_headers
+                    logger.info(f"[MCP Auth] Added Authorization header for {url}")
+
+                # Pass session_id to MCP toolkit if available
+                if session_id:
+                    settings['session_id'] = session_id
+                    logger.info(f"[MCP Auth] Passing session_id to toolkit: {session_id}")
+                tools.extend(McpToolkit.get_toolkit(
+                    toolkit_name=tool.get('toolkit_name', ''),
+                    client=alita_client,
+                    **settings).get_tools())
+        except Exception as e:
+            if isinstance(e, McpAuthorizationRequired):
+                raise
+            logger.error(f"Error initializing toolkit for tool '{tool.get('name', 'unknown')}': {e}", exc_info=True)
+            if debug_mode:
+                logger.info("Skipping tool initialization error due to debug mode.")
+                continue
+            else:
+                raise ToolException(f"Error initializing toolkit for tool '{tool.get('name', 'unknown')}': {e}")
 
     if len(prompts) > 0:
         tools += PromptToolkit.get_toolkit(alita_client, prompts).get_tools()
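
For context, the token-injection branch above expects mcp_tokens to be keyed by canonical_resource(url), with values that are either a bare token string (legacy) or a dict carrying access_token and an optional session_id. A minimal sketch of how a caller might assemble such a mapping and pass it to the new get_tools() signature; the server URLs and token values are made up, and tools_list, alita_client and llm are assumed to already exist:

    from alita_sdk.runtime.utils.mcp_oauth import canonical_resource
    from alita_sdk.runtime.toolkits.tools import get_tools

    mcp_tokens = {
        # legacy format: plain bearer token string
        canonical_resource("https://mcp.example.com/sse"): "eyJhbGciOi...",
        # new format: dict with access_token and optional session_id
        canonical_resource("https://other-mcp.example.com/mcp"): {
            "access_token": "eyJhbGciOi...",
            "session_id": "sess-123",
        },
    }

    tools = get_tools(tools_list, alita_client, llm,
                      debug_mode=True, mcp_tokens=mcp_tokens)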

@@ -90,7 +170,8 @@ def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = Non
     tools += community_tools(tools_list, alita_client, llm)
     # Add alita tools
     tools += alita_tools(tools_list, alita_client, llm, memory_store)
-    # Add MCP tools
+    # Add MCP tools registered via alita-mcp CLI (static registry)
+    # Note: Tools with type='mcp' are already handled in main loop above
     tools += _mcp_tools(tools_list, alita_client)
 
     # Sanitize tool names to meet OpenAI's function naming requirements
@@ -145,6 +226,10 @@ def _sanitize_tool_names(tools: list) -> list:
 
 
 def _mcp_tools(tools_list, alita):
+    """
+    Handle MCP tools registered via alita-mcp CLI (static registry).
+    Skips tools with type='mcp' as those are handled by dynamic discovery.
+    """
     try:
         all_available_toolkits = alita.get_mcp_toolkits()
         toolkit_lookup = {tk["name"]: tk for tk in all_available_toolkits}
@@ -152,6 +237,11 @@
         #
         for selected_toolkit in tools_list:
             server_toolkit_name = selected_toolkit['type']
+
+            # Skip tools with type='mcp' - they're handled by dynamic discovery
+            if server_toolkit_name == 'mcp':
+                continue
+
             toolkit_conf = toolkit_lookup.get(server_toolkit_name)
             #
             if not toolkit_conf:
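
For context, tool entries of type 'mcp' are now resolved directly in the main loop via McpToolkit, while other types are still matched against the static registry returned by alita.get_mcp_toolkits(). A sketch of a remote-MCP tool entry as the loop reads it; all field values are hypothetical:

    tool_entry = {
        "type": "mcp",
        "name": "docs-server",
        "toolkit_name": "docs_mcp",
        "settings": {
            "url": "https://mcp.example.com/sse",
            # optional; an Authorization header is injected from mcp_tokens when available
            "headers": {"X-Org": "acme"},
        },
    }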

alita_sdk/runtime/tools/__init__.py

@@ -5,7 +5,11 @@ This module provides various tools that can be used within LangGraph agents.
 
 from .sandbox import PyodideSandboxTool, StatefulPyodideSandboxTool, create_sandbox_tool
 from .echo import EchoTool
-from .image_generation import ImageGenerationTool, create_image_generation_tool
+from .image_generation import (
+    ImageGenerationTool,
+    create_image_generation_tool,
+    ImageGenerationToolkit
+)
 
 __all__ = [
     "PyodideSandboxTool",
@@ -13,5 +17,6 @@ __all__ = [
     "create_sandbox_tool",
    "EchoTool",
    "ImageGenerationTool",
+    "ImageGenerationToolkit",
    "create_image_generation_tool"
-]
+]

alita_sdk/runtime/tools/application.py

@@ -50,6 +50,8 @@ class Application(BaseTool):
     application: Any
     args_schema: Type[BaseModel] = applicationToolSchema
     return_type: str = "str"
+    client: Any
+    args_runnable: dict = {}
 
     @field_validator('name', mode='before')
     @classmethod
@@ -66,6 +68,11 @@ class Application(BaseTool):
         return self._run(*config, **all_kwargs)
 
     def _run(self, *args, **kwargs):
+        if self.client and self.args_runnable:
+            # Recreate new LanggraphAgentRunnable in order to reflect the current input_mapping (it can be dynamic for pipelines).
+            # Actually, for pipelines agent toolkits LanggraphAgentRunnable is created (for LLMNode) before pipeline's schema parsing.
+            application_variables = {k: {"name": k, "value": v} for k, v in kwargs.items()}
+            self.application = self.client.application(**self.args_runnable, application_variables=application_variables)
         response = self.application.invoke(formulate_query(kwargs))
         if self.return_type == "str":
             return response["output"]
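
For context, when client and args_runnable are set, _run rebuilds the application runnable on every call so dynamic pipeline input mappings are respected. A sketch of the resulting payload shape, assuming a call with plain keyword arguments such as topic='billing' and limit=5:

    application_variables = {
        "topic": {"name": "topic", "value": "billing"},
        "limit": {"name": "limit", "value": 5},
    }
    # then: self.application = self.client.application(**self.args_runnable,
    #                                                  application_variables=application_variables)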

alita_sdk/runtime/tools/function.py

@@ -1,18 +1,33 @@
+import json
 import logging
+from copy import deepcopy
 from json import dumps
 
 from langchain_core.callbacks import dispatch_custom_event
 from langchain_core.messages import ToolCall
 from langchain_core.runnables import RunnableConfig
 from langchain_core.tools import BaseTool, ToolException
-from typing import Any, Optional, Union, Annotated
+from typing import Any, Optional, Union
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from pydantic import ValidationError
+
 from ..langchain.utils import propagate_the_input_mapping
 
 logger = logging.getLogger(__name__)
 
 
+def replace_escaped_newlines(data):
+    """
+    Replace \\n with \n in all string values recursively.
+    Required for sanitization of state variables in code node
+    """
+    if isinstance(data, dict):
+        return {key: replace_escaped_newlines(value) for key, value in data.items()}
+    elif isinstance(data, str):
+        return data.replace('\\n', '\n')
+    else:
+        return data
+
 class FunctionTool(BaseTool):
     name: str = 'FunctionalTool'
     description: str = 'This is direct call node for tools'
@@ -21,6 +36,61 @@ class FunctionTool(BaseTool):
     input_variables: Optional[list[str]] = None
     input_mapping: Optional[dict[str, dict]] = None
     output_variables: Optional[list[str]] = None
+    structured_output: Optional[bool] = False
+    alita_client: Optional[Any] = None
+
+    def _prepare_pyodide_input(self, state: Union[str, dict, ToolCall]) -> str:
+        """Prepare input for PyodideSandboxTool by injecting state into the code block."""
+        # add state into the code block here since it might be changed during the execution of the code
+        state_copy = replace_escaped_newlines(deepcopy(state))
+
+        del state_copy['messages'] # remove messages to avoid issues with pickling without langchain-core
+        # inject state into the code block as alita_state variable
+        state_json = json.dumps(state_copy, ensure_ascii=False)
+        pyodide_predata = f'#state dict\nimport json\nalita_state = json.loads({json.dumps(state_json)})\n'
+
+        return pyodide_predata
+
+    def _handle_pyodide_output(self, tool_result: Any) -> dict:
+        """Handle output processing for PyodideSandboxTool results."""
+        tool_result_converted = {}
+
+        if self.output_variables:
+            for var in self.output_variables:
+                if var == "messages":
+                    tool_result_converted.update(
+                        {"messages": [{"role": "assistant", "content": dumps(tool_result)}]})
+                    continue
+                if isinstance(tool_result, dict) and var in tool_result:
+                    tool_result_converted[var] = tool_result[var]
+                else:
+                    # handler in case user points to a var that is not in the output of the tool
+                    tool_result_converted[var] = tool_result.get('result',
+                                                                 tool_result.get('error') if tool_result.get('error')
+                                                                 else 'Execution result is missing')
+        else:
+            tool_result_converted.update({"messages": [{"role": "assistant", "content": dumps(tool_result)}]})
+
+        if self.structured_output:
+            # execute code tool and update state variables
+            try:
+                result_value = tool_result.get('result', {})
+                if isinstance(result_value, dict):
+                    tool_result_converted.update(result_value)
+                elif isinstance(result_value, list):
+                    # Handle list case - could wrap in a key or handle differently based on requirements
+                    tool_result_converted.update({"result": result_value})
+                else:
+                    # Handle JSON string case
+                    tool_result_converted.update(json.loads(result_value))
+            except json.JSONDecodeError:
+                logger.error(f"JSONDecodeError: {tool_result}")
+
+        return tool_result_converted
+
+    def _is_pyodide_tool(self) -> bool:
+        """Check if the current tool is a PyodideSandboxTool."""
+        return self.tool.name.lower() == 'pyodide_sandbox'
 
     def invoke(
         self,
@@ -31,8 +101,14 @@
         params = convert_to_openai_tool(self.tool).get(
             'function', {'parameters': {}}).get(
             'parameters', {'properties': {}}).get('properties', {})
+
        func_args = propagate_the_input_mapping(input_mapping=self.input_mapping, input_variables=self.input_variables,
                                                state=state)
+
+        # special handler for PyodideSandboxTool
+        if self._is_pyodide_tool():
+            code = func_args['code']
+            func_args['code'] = f"{self._prepare_pyodide_input(state)}\n{code}"
         try:
             tool_result = self.tool.invoke(func_args, config, **kwargs)
             dispatch_custom_event(
@@ -44,17 +120,30 @@
                 }, config=config
             )
             logger.info(f"ToolNode response: {tool_result}")
+
+            # handler for PyodideSandboxTool
+            if self._is_pyodide_tool():
+                return self._handle_pyodide_output(tool_result)
+
             if not self.output_variables:
                 return {"messages": [{"role": "assistant", "content": dumps(tool_result)}]}
             else:
-                if self.output_variables[0] == "messages":
-                    return {
+                if "messages" in self.output_variables:
+                    messages_dict = {
                         "messages": [{
                             "role": "assistant",
-                            "content": dumps(tool_result) if not isinstance(tool_result, ToolException) else str(
-                                tool_result)
+                            "content": dumps(tool_result)
+                            if not isinstance(tool_result, ToolException) and not isinstance(tool_result, str)
+                            else str(tool_result)
                         }]
                     }
+                    for var in self.output_variables:
+                        if var != "messages":
+                            if isinstance(tool_result, dict) and var in tool_result:
+                                messages_dict[var] = tool_result[var]
+                            else:
+                                messages_dict[var] = tool_result
+                    return messages_dict
                 else:
                     return { self.output_variables[0]: tool_result }
         except ValidationError:
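
For context, _prepare_pyodide_input double-encodes the graph state so it can be embedded as a string literal inside the generated sandbox code. A sketch of what gets prepended for a small state, with made-up values:

    state = {"messages": ["..."], "counter": 2}
    # 'messages' is dropped, the remainder is JSON-encoded twice, and the prefix
    # prepended to the node's code is roughly:
    #
    #   #state dict
    #   import json
    #   alita_state = json.loads("{\"counter\": 2}")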

alita_sdk/runtime/tools/graph.py

@@ -47,8 +47,8 @@ def formulate_query(kwargs):
 
 
 class GraphTool(BaseTool):
-    name: str
-    description: str
+    name: str = 'GraphTool'
+    description: str = 'Graph tool for tools'
     graph: CompiledStateGraph
     args_schema: Type[BaseModel] = graphToolSchema
     return_type: str = "str"
@@ -65,10 +65,16 @@ class GraphTool(BaseTool):
         all_kwargs = {**kwargs, **extras, **schema_values}
         if config is None:
             config = {}
-        return self._run(*config, **all_kwargs)
+        # Pass the config to the _run empty or the one passed from the parent executor.
+        return self._run(config, **all_kwargs)
 
     def _run(self, *args, **kwargs):
-        response = self.graph.invoke(formulate_query(kwargs))
+        config = None
+        # From invoke method we are passing only 1 arg so it is safe to do this condition and config assignment.
+        # Default to None is safe because it will be checked also on the langchain side.
+        if args:
+            config = args[0]
+        response = self.graph.invoke(formulate_query(kwargs), config=config)
         if self.return_type == "str":
             return response["output"]
         else:
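
For context, GraphTool.invoke now forwards the parent executor's RunnableConfig into the compiled subgraph, so callbacks and tracing metadata reach nested runs (previously the config was unpacked with *config and effectively dropped). A sketch of the effect; the tool instance and input key are hypothetical:

    from langchain_core.callbacks import StdOutCallbackHandler

    config = {"callbacks": [StdOutCallbackHandler()]}
    graph_tool.invoke({"task": "summarize the report"}, config=config)
    # internally: self.graph.invoke(formulate_query(kwargs), config=config)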

alita_sdk/runtime/tools/image_generation.py

@@ -2,16 +2,59 @@
 Image generation tool for Alita SDK.
 """
 import logging
-from typing import Optional, Type, Any
-from langchain_core.tools import BaseTool
-from pydantic import BaseModel, Field
+from typing import Optional, Type, Any, List, Literal
+from langchain_core.tools import BaseTool, BaseToolkit
+from pydantic import BaseModel, Field, create_model, ConfigDict
 
 logger = logging.getLogger(__name__)
 
+name = "image_generation"
+
+
+def get_tools(tools_list: list, alita_client=None, llm=None,
+              memory_store=None):
+    """
+    Get image generation tools for the provided tool configurations.
+
+    Args:
+        tools_list: List of tool configurations
+        alita_client: Alita client instance (required for image generation)
+        llm: LLM client instance (unused for image generation)
+        memory_store: Optional memory store instance (unused)
+
+    Returns:
+        List of image generation tools
+    """
+    all_tools = []
+
+    for tool in tools_list:
+        if (tool.get('type') == 'image_generation' or
+                tool.get('toolkit_name') == 'image_generation'):
+            try:
+                if not alita_client:
+                    logger.error("Alita client is required for image "
+                                 "generation tools")
+                    continue
+
+                toolkit_instance = ImageGenerationToolkit.get_toolkit(
+                    client=alita_client,
+                    toolkit_name=tool.get('toolkit_name', '')
+                )
+                all_tools.extend(toolkit_instance.get_tools())
+            except Exception as e:
+                logger.error(f"Error in image generation toolkit "
+                             f"get_tools: {e}")
+                logger.error(f"Tool config: {tool}")
+                raise
+
+    return all_tools
+
 
 class ImageGenerationInput(BaseModel):
     """Input schema for image generation tool."""
-    prompt: str = Field(description="Text prompt describing the image to generate")
+    prompt: str = Field(
+        description="Text prompt describing the image to generate"
+    )
     n: int = Field(
         default=1, description="Number of images to generate (1-10)",
         ge=1, le=10
@@ -22,7 +65,7 @@ class ImageGenerationInput(BaseModel):
     )
     quality: str = Field(
         default="auto",
-        description="Quality of the generated image ('low', 'medium', 'high', 'auto')"
+        description="Quality of the generated image ('low', 'medium', 'high')"
     )
     style: Optional[str] = Field(
         default=None, description="Style of the generated image (optional)"
@@ -69,7 +112,8 @@ class ImageGenerationTool(BaseTool):
             else:
                 content_chunks.append({
                     "type": "text",
-                    "text": f"Generated {len(images)} images for prompt: '{prompt}'"
+                    "text": f"Generated {len(images)} images for "
+                            f"prompt: '{prompt}'"
                 })
 
             # Add image content for each generated image
@@ -85,7 +129,8 @@
                 content_chunks.append({
                     "type": "image_url",
                     "image_url": {
-                        "url": f"data:image/png;base64,{image_data['b64_json']}"
+                        "url": f"data:image/png;base64,"
+                               f"{image_data['b64_json']}"
                     }
                 })
 
@@ -94,7 +139,8 @@
            # Fallback to text response if no images in result
            return [{
                "type": "text",
-                "text": f"Image generation completed but no images returned: {result}"
+                "text": f"Image generation completed but no images "
+                        f"returned: {result}"
            }]
 
        except Exception as e:
@@ -114,3 +160,53 @@
 def create_image_generation_tool(client):
     """Create an image generation tool with the provided Alita client."""
     return ImageGenerationTool(client=client)
+
+
+class ImageGenerationToolkit(BaseToolkit):
+    """Toolkit for image generation tools."""
+    tools: List[BaseTool] = []
+
+    @staticmethod
+    def toolkit_config_schema() -> BaseModel:
+        """Get the configuration schema for the image generation toolkit."""
+        # Create sample tool to get schema
+        sample_tool = ImageGenerationTool(client=None)
+        selected_tools = {sample_tool.name: sample_tool.args_schema.schema()}
+
+        return create_model(
+            'image_generation',
+            selected_tools=(
+                List[Literal[tuple(selected_tools)]],
+                Field(
+                    default=[],
+                    json_schema_extra={'args_schemas': selected_tools}
+                )
+            ),
+            __config__=ConfigDict(json_schema_extra={
+                'metadata': {
+                    "label": "Image Generation",
+                    "icon_url": "image_generation.svg",
+                    "hidden": True,
+                    "categories": ["internal_tool"],
+                    "extra_categories": ["image generation"],
+                }
+            })
+        )
+
+    @classmethod
+    def get_toolkit(cls, client=None, **kwargs):
+        """
+        Get toolkit with image generation tools.
+
+        Args:
+            client: Alita client instance (required)
+            **kwargs: Additional arguments
+        """
+        if not client:
+            raise ValueError("Alita client is required for image generation")
+
+        tools = [ImageGenerationTool(client=client)]
+        return cls(tools=tools)
+
+    def get_tools(self):
+        return self.tools
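
For context, the new ImageGenerationToolkit mirrors the other core toolkits: toolkit_config_schema() exposes it in the toolkit catalog (hidden, categorized as internal_tool), and get_toolkit() wraps a single ImageGenerationTool. A minimal usage sketch, assuming alita_client is an already-configured Alita client with an image-generation model:

    from alita_sdk.runtime.tools.image_generation import ImageGenerationToolkit

    toolkit = ImageGenerationToolkit.get_toolkit(client=alita_client)
    image_tool = toolkit.get_tools()[0]
    # input fields follow ImageGenerationInput: prompt, n, quality, style, ...
    chunks = image_tool.invoke({"prompt": "a lighthouse at dawn", "n": 1})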