alita-sdk 0.3.379__py3-none-any.whl → 0.3.462__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent_executor.py +144 -0
- alita_sdk/cli/agent_loader.py +197 -0
- alita_sdk/cli/agent_ui.py +166 -0
- alita_sdk/cli/agents.py +1069 -0
- alita_sdk/cli/callbacks.py +576 -0
- alita_sdk/cli/cli.py +159 -0
- alita_sdk/cli/config.py +153 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/toolkit.py +330 -0
- alita_sdk/cli/toolkit_loader.py +55 -0
- alita_sdk/cli/tools/__init__.py +9 -0
- alita_sdk/cli/tools/filesystem.py +905 -0
- alita_sdk/configurations/bitbucket.py +95 -0
- alita_sdk/configurations/confluence.py +96 -1
- alita_sdk/configurations/gitlab.py +79 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +93 -0
- alita_sdk/configurations/zephyr_enterprise.py +93 -0
- alita_sdk/configurations/zephyr_essential.py +75 -0
- alita_sdk/runtime/clients/client.py +47 -10
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/clients/sandbox_client.py +8 -0
- alita_sdk/runtime/langchain/assistant.py +37 -16
- alita_sdk/runtime/langchain/constants.py +6 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -1
- alita_sdk/runtime/langchain/document_loaders/constants.py +28 -12
- alita_sdk/runtime/langchain/langraph_agent.py +146 -31
- alita_sdk/runtime/langchain/utils.py +39 -7
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/toolkits/__init__.py +24 -0
- alita_sdk/runtime/toolkits/application.py +8 -1
- alita_sdk/runtime/toolkits/artifact.py +5 -6
- alita_sdk/runtime/toolkits/mcp.py +895 -0
- alita_sdk/runtime/toolkits/tools.py +137 -56
- alita_sdk/runtime/tools/__init__.py +7 -2
- alita_sdk/runtime/tools/application.py +7 -0
- alita_sdk/runtime/tools/function.py +29 -25
- alita_sdk/runtime/tools/graph.py +10 -4
- alita_sdk/runtime/tools/image_generation.py +104 -8
- alita_sdk/runtime/tools/llm.py +204 -114
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
- alita_sdk/runtime/tools/sandbox.py +57 -43
- alita_sdk/runtime/tools/vectorstore.py +2 -1
- alita_sdk/runtime/tools/vectorstore_base.py +19 -3
- alita_sdk/runtime/utils/mcp_oauth.py +164 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
- alita_sdk/runtime/utils/streamlit.py +34 -3
- alita_sdk/runtime/utils/toolkit_utils.py +14 -4
- alita_sdk/tools/__init__.py +46 -31
- alita_sdk/tools/ado/repos/__init__.py +1 -0
- alita_sdk/tools/ado/test_plan/__init__.py +1 -1
- alita_sdk/tools/ado/wiki/__init__.py +1 -5
- alita_sdk/tools/ado/work_item/__init__.py +1 -5
- alita_sdk/tools/ado/work_item/ado_wrapper.py +17 -8
- alita_sdk/tools/base_indexer_toolkit.py +105 -43
- alita_sdk/tools/bitbucket/__init__.py +1 -0
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/code/sonar/__init__.py +1 -1
- alita_sdk/tools/code_indexer_toolkit.py +13 -3
- alita_sdk/tools/confluence/__init__.py +2 -2
- alita_sdk/tools/confluence/api_wrapper.py +29 -7
- alita_sdk/tools/confluence/loader.py +10 -0
- alita_sdk/tools/github/__init__.py +2 -2
- alita_sdk/tools/gitlab/__init__.py +2 -1
- alita_sdk/tools/gitlab/api_wrapper.py +11 -7
- alita_sdk/tools/gitlab_org/__init__.py +1 -2
- alita_sdk/tools/google_places/__init__.py +2 -1
- alita_sdk/tools/jira/__init__.py +1 -0
- alita_sdk/tools/jira/api_wrapper.py +1 -1
- alita_sdk/tools/memory/__init__.py +1 -1
- alita_sdk/tools/openapi/__init__.py +10 -1
- alita_sdk/tools/pandas/__init__.py +1 -1
- alita_sdk/tools/postman/__init__.py +2 -1
- alita_sdk/tools/pptx/__init__.py +2 -2
- alita_sdk/tools/qtest/__init__.py +3 -3
- alita_sdk/tools/qtest/api_wrapper.py +1708 -76
- alita_sdk/tools/rally/__init__.py +1 -2
- alita_sdk/tools/report_portal/__init__.py +1 -0
- alita_sdk/tools/salesforce/__init__.py +1 -0
- alita_sdk/tools/servicenow/__init__.py +2 -3
- alita_sdk/tools/sharepoint/__init__.py +1 -0
- alita_sdk/tools/sharepoint/api_wrapper.py +125 -34
- alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/slack/__init__.py +1 -0
- alita_sdk/tools/sql/__init__.py +2 -1
- alita_sdk/tools/testio/__init__.py +1 -0
- alita_sdk/tools/testrail/__init__.py +1 -3
- alita_sdk/tools/utils/content_parser.py +27 -16
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +18 -5
- alita_sdk/tools/xray/__init__.py +2 -1
- alita_sdk/tools/zephyr/__init__.py +2 -1
- alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
- alita_sdk/tools/zephyr_essential/__init__.py +1 -0
- alita_sdk/tools/zephyr_scale/__init__.py +1 -0
- alita_sdk/tools/zephyr_squad/__init__.py +1 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/METADATA +8 -2
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/RECORD +110 -86
- alita_sdk-0.3.462.dist-info/entry_points.txt +2 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
from typing import Optional
|
|
2
3
|
|
|
3
4
|
from langchain_core.tools import ToolException
|
|
4
5
|
from langgraph.store.base import BaseStore
|
|
@@ -11,11 +12,14 @@ from .datasource import DatasourcesToolkit
|
|
|
11
12
|
from .prompt import PromptToolkit
|
|
12
13
|
from .subgraph import SubgraphToolkit
|
|
13
14
|
from .vectorstore import VectorStoreToolkit
|
|
15
|
+
from .mcp import McpToolkit
|
|
14
16
|
from ..tools.mcp_server_tool import McpServerTool
|
|
15
17
|
from ..tools.sandbox import SandboxToolkit
|
|
18
|
+
from ..tools.image_generation import ImageGenerationToolkit
|
|
16
19
|
# Import community tools
|
|
17
20
|
from ...community import get_toolkits as community_toolkits, get_tools as community_tools
|
|
18
21
|
from ...tools.memory import MemoryToolkit
|
|
22
|
+
from ..utils.mcp_oauth import canonical_resource, McpAuthorizationRequired
|
|
19
23
|
from ...tools.utils import TOOLKIT_SPLITTER
|
|
20
24
|
|
|
21
25
|
logger = logging.getLogger(__name__)
|
|
@@ -26,71 +30,138 @@ def get_toolkits():
|
|
|
26
30
|
ArtifactToolkit.toolkit_config_schema(),
|
|
27
31
|
MemoryToolkit.toolkit_config_schema(),
|
|
28
32
|
VectorStoreToolkit.toolkit_config_schema(),
|
|
29
|
-
SandboxToolkit.toolkit_config_schema()
|
|
33
|
+
SandboxToolkit.toolkit_config_schema(),
|
|
34
|
+
ImageGenerationToolkit.toolkit_config_schema(),
|
|
35
|
+
McpToolkit.toolkit_config_schema()
|
|
30
36
|
]
|
|
31
37
|
|
|
32
38
|
return core_toolkits + community_toolkits() + alita_toolkits()
|
|
33
39
|
|
|
34
40
|
|
|
35
|
-
def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = None) -> list:
|
|
41
|
+
def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = None, debug_mode: Optional[bool] = False, mcp_tokens: Optional[dict] = None) -> list:
|
|
36
42
|
prompts = []
|
|
37
43
|
tools = []
|
|
38
44
|
|
|
39
45
|
for tool in tools_list:
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
elif tool['type'] == 'internal_tool':
|
|
72
|
-
if tool['name'] == 'pyodide':
|
|
73
|
-
tools += SandboxToolkit.get_toolkit(
|
|
74
|
-
stateful=False,
|
|
75
|
-
allow_net=True,
|
|
46
|
+
try:
|
|
47
|
+
if tool['type'] == 'datasource':
|
|
48
|
+
tools.extend(DatasourcesToolkit.get_toolkit(
|
|
49
|
+
alita_client,
|
|
50
|
+
datasource_ids=[int(tool['settings']['datasource_id'])],
|
|
51
|
+
selected_tools=tool['settings']['selected_tools'],
|
|
52
|
+
toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
|
|
53
|
+
).get_tools())
|
|
54
|
+
elif tool['type'] == 'application':
|
|
55
|
+
tools.extend(ApplicationToolkit.get_toolkit(
|
|
56
|
+
alita_client,
|
|
57
|
+
application_id=int(tool['settings']['application_id']),
|
|
58
|
+
application_version_id=int(tool['settings']['application_version_id']),
|
|
59
|
+
selected_tools=[]
|
|
60
|
+
).get_tools())
|
|
61
|
+
# backward compatibility for pipeline application type as subgraph node
|
|
62
|
+
if tool.get('agent_type', '') == 'pipeline':
|
|
63
|
+
# static get_toolkit returns a list of CompiledStateGraph stubs
|
|
64
|
+
tools.extend(SubgraphToolkit.get_toolkit(
|
|
65
|
+
alita_client,
|
|
66
|
+
application_id=int(tool['settings']['application_id']),
|
|
67
|
+
application_version_id=int(tool['settings']['application_version_id']),
|
|
68
|
+
app_api_key=alita_client.auth_token,
|
|
69
|
+
selected_tools=[],
|
|
70
|
+
llm=llm
|
|
71
|
+
))
|
|
72
|
+
elif tool['type'] == 'memory':
|
|
73
|
+
tools += MemoryToolkit.get_toolkit(
|
|
74
|
+
namespace=tool['settings'].get('namespace', str(tool['id'])),
|
|
75
|
+
pgvector_configuration=tool['settings'].get('pgvector_configuration', {}),
|
|
76
|
+
store=memory_store,
|
|
76
77
|
).get_tools()
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
78
|
+
# TODO: update configuration of internal tools
|
|
79
|
+
elif tool['type'] == 'internal_tool':
|
|
80
|
+
if tool['name'] == 'pyodide':
|
|
81
|
+
tools += SandboxToolkit.get_toolkit(
|
|
82
|
+
stateful=False,
|
|
83
|
+
allow_net=True,
|
|
84
|
+
alita_client=alita_client,
|
|
85
|
+
).get_tools()
|
|
86
|
+
elif tool['name'] == 'image_generation':
|
|
87
|
+
if alita_client and alita_client.model_image_generation:
|
|
88
|
+
tools += ImageGenerationToolkit.get_toolkit(
|
|
89
|
+
client=alita_client,
|
|
90
|
+
).get_tools()
|
|
91
|
+
else:
|
|
92
|
+
logger.warning("Image generation internal tool requested "
|
|
93
|
+
"but no image generation model configured")
|
|
94
|
+
elif tool['type'] == 'artifact':
|
|
95
|
+
tools.extend(ArtifactToolkit.get_toolkit(
|
|
96
|
+
client=alita_client,
|
|
97
|
+
bucket=tool['settings']['bucket'],
|
|
98
|
+
toolkit_name=tool.get('toolkit_name', ''),
|
|
99
|
+
selected_tools=tool['settings'].get('selected_tools', []),
|
|
100
|
+
llm=llm,
|
|
101
|
+
# indexer settings
|
|
102
|
+
pgvector_configuration=tool['settings'].get('pgvector_configuration', {}),
|
|
103
|
+
embedding_model=tool['settings'].get('embedding_model'),
|
|
104
|
+
collection_name=f"{tool.get('toolkit_name')}",
|
|
105
|
+
collection_schema = str(tool['id'])
|
|
106
|
+
).get_tools())
|
|
107
|
+
elif tool['type'] == 'vectorstore':
|
|
108
|
+
tools.extend(VectorStoreToolkit.get_toolkit(
|
|
109
|
+
llm=llm,
|
|
110
|
+
toolkit_name=tool.get('toolkit_name', ''),
|
|
111
|
+
**tool['settings']).get_tools())
|
|
112
|
+
elif tool['type'] == 'mcp':
|
|
113
|
+
# remote mcp tool initialization with token injection
|
|
114
|
+
settings = dict(tool['settings'])
|
|
115
|
+
url = settings.get('url')
|
|
116
|
+
headers = settings.get('headers')
|
|
117
|
+
token_data = None
|
|
118
|
+
session_id = None
|
|
119
|
+
if mcp_tokens and url:
|
|
120
|
+
canonical_url = canonical_resource(url)
|
|
121
|
+
logger.info(f"[MCP Auth] Looking for token for URL: {url}")
|
|
122
|
+
logger.info(f"[MCP Auth] Canonical URL: {canonical_url}")
|
|
123
|
+
logger.info(f"[MCP Auth] Available tokens: {list(mcp_tokens.keys())}")
|
|
124
|
+
token_data = mcp_tokens.get(canonical_url)
|
|
125
|
+
if token_data:
|
|
126
|
+
logger.info(f"[MCP Auth] Found token data for {canonical_url}")
|
|
127
|
+
# Handle both old format (string) and new format (dict with access_token and session_id)
|
|
128
|
+
if isinstance(token_data, dict):
|
|
129
|
+
access_token = token_data.get('access_token')
|
|
130
|
+
session_id = token_data.get('session_id')
|
|
131
|
+
logger.info(f"[MCP Auth] Token data: access_token={'present' if access_token else 'missing'}, session_id={session_id or 'none'}")
|
|
132
|
+
else:
|
|
133
|
+
# Backward compatibility: treat as plain token string
|
|
134
|
+
access_token = token_data
|
|
135
|
+
logger.info(f"[MCP Auth] Using legacy token format (string)")
|
|
136
|
+
else:
|
|
137
|
+
access_token = None
|
|
138
|
+
logger.warning(f"[MCP Auth] No token found for {canonical_url}")
|
|
139
|
+
else:
|
|
140
|
+
access_token = None
|
|
141
|
+
|
|
142
|
+
if access_token:
|
|
143
|
+
merged_headers = dict(headers) if headers else {}
|
|
144
|
+
merged_headers.setdefault('Authorization', f'Bearer {access_token}')
|
|
145
|
+
settings['headers'] = merged_headers
|
|
146
|
+
logger.info(f"[MCP Auth] Added Authorization header for {url}")
|
|
147
|
+
|
|
148
|
+
# Pass session_id to MCP toolkit if available
|
|
149
|
+
if session_id:
|
|
150
|
+
settings['session_id'] = session_id
|
|
151
|
+
logger.info(f"[MCP Auth] Passing session_id to toolkit: {session_id}")
|
|
152
|
+
tools.extend(McpToolkit.get_toolkit(
|
|
153
|
+
toolkit_name=tool.get('toolkit_name', ''),
|
|
154
|
+
client=alita_client,
|
|
155
|
+
**settings).get_tools())
|
|
156
|
+
except Exception as e:
|
|
157
|
+
if isinstance(e, McpAuthorizationRequired):
|
|
158
|
+
raise
|
|
159
|
+
logger.error(f"Error initializing toolkit for tool '{tool.get('name', 'unknown')}': {e}", exc_info=True)
|
|
160
|
+
if debug_mode:
|
|
161
|
+
logger.info("Skipping tool initialization error due to debug mode.")
|
|
162
|
+
continue
|
|
163
|
+
else:
|
|
164
|
+
raise ToolException(f"Error initializing toolkit for tool '{tool.get('name', 'unknown')}': {e}")
|
|
94
165
|
|
|
95
166
|
if len(prompts) > 0:
|
|
96
167
|
tools += PromptToolkit.get_toolkit(alita_client, prompts).get_tools()
|
|
@@ -99,7 +170,8 @@ def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = Non
|
|
|
99
170
|
tools += community_tools(tools_list, alita_client, llm)
|
|
100
171
|
# Add alita tools
|
|
101
172
|
tools += alita_tools(tools_list, alita_client, llm, memory_store)
|
|
102
|
-
# Add MCP tools
|
|
173
|
+
# Add MCP tools registered via alita-mcp CLI (static registry)
|
|
174
|
+
# Note: Tools with type='mcp' are already handled in main loop above
|
|
103
175
|
tools += _mcp_tools(tools_list, alita_client)
|
|
104
176
|
|
|
105
177
|
# Sanitize tool names to meet OpenAI's function naming requirements
|
|
@@ -154,6 +226,10 @@ def _sanitize_tool_names(tools: list) -> list:
|
|
|
154
226
|
|
|
155
227
|
|
|
156
228
|
def _mcp_tools(tools_list, alita):
|
|
229
|
+
"""
|
|
230
|
+
Handle MCP tools registered via alita-mcp CLI (static registry).
|
|
231
|
+
Skips tools with type='mcp' as those are handled by dynamic discovery.
|
|
232
|
+
"""
|
|
157
233
|
try:
|
|
158
234
|
all_available_toolkits = alita.get_mcp_toolkits()
|
|
159
235
|
toolkit_lookup = {tk["name"]: tk for tk in all_available_toolkits}
|
|
@@ -161,6 +237,11 @@ def _mcp_tools(tools_list, alita):
|
|
|
161
237
|
#
|
|
162
238
|
for selected_toolkit in tools_list:
|
|
163
239
|
server_toolkit_name = selected_toolkit['type']
|
|
240
|
+
|
|
241
|
+
# Skip tools with type='mcp' - they're handled by dynamic discovery
|
|
242
|
+
if server_toolkit_name == 'mcp':
|
|
243
|
+
continue
|
|
244
|
+
|
|
164
245
|
toolkit_conf = toolkit_lookup.get(server_toolkit_name)
|
|
165
246
|
#
|
|
166
247
|
if not toolkit_conf:
|
|
@@ -5,7 +5,11 @@ This module provides various tools that can be used within LangGraph agents.
|
|
|
5
5
|
|
|
6
6
|
from .sandbox import PyodideSandboxTool, StatefulPyodideSandboxTool, create_sandbox_tool
|
|
7
7
|
from .echo import EchoTool
|
|
8
|
-
from .image_generation import
|
|
8
|
+
from .image_generation import (
|
|
9
|
+
ImageGenerationTool,
|
|
10
|
+
create_image_generation_tool,
|
|
11
|
+
ImageGenerationToolkit
|
|
12
|
+
)
|
|
9
13
|
|
|
10
14
|
__all__ = [
|
|
11
15
|
"PyodideSandboxTool",
|
|
@@ -13,5 +17,6 @@ __all__ = [
|
|
|
13
17
|
"create_sandbox_tool",
|
|
14
18
|
"EchoTool",
|
|
15
19
|
"ImageGenerationTool",
|
|
20
|
+
"ImageGenerationToolkit",
|
|
16
21
|
"create_image_generation_tool"
|
|
17
|
-
]
|
|
22
|
+
]
|
|
@@ -50,6 +50,8 @@ class Application(BaseTool):
|
|
|
50
50
|
application: Any
|
|
51
51
|
args_schema: Type[BaseModel] = applicationToolSchema
|
|
52
52
|
return_type: str = "str"
|
|
53
|
+
client: Any
|
|
54
|
+
args_runnable: dict = {}
|
|
53
55
|
|
|
54
56
|
@field_validator('name', mode='before')
|
|
55
57
|
@classmethod
|
|
@@ -66,6 +68,11 @@ class Application(BaseTool):
|
|
|
66
68
|
return self._run(*config, **all_kwargs)
|
|
67
69
|
|
|
68
70
|
def _run(self, *args, **kwargs):
|
|
71
|
+
if self.client and self.args_runnable:
|
|
72
|
+
# Recreate new LanggraphAgentRunnable in order to reflect the current input_mapping (it can be dynamic for pipelines).
|
|
73
|
+
# Actually, for pipelines agent toolkits LanggraphAgentRunnable is created (for LLMNode) before pipeline's schema parsing.
|
|
74
|
+
application_variables = {k: {"name": k, "value": v} for k, v in kwargs.items()}
|
|
75
|
+
self.application = self.client.application(**self.args_runnable, application_variables=application_variables)
|
|
69
76
|
response = self.application.invoke(formulate_query(kwargs))
|
|
70
77
|
if self.return_type == "str":
|
|
71
78
|
return response["output"]
|
|
@@ -7,7 +7,7 @@ from langchain_core.callbacks import dispatch_custom_event
|
|
|
7
7
|
from langchain_core.messages import ToolCall
|
|
8
8
|
from langchain_core.runnables import RunnableConfig
|
|
9
9
|
from langchain_core.tools import BaseTool, ToolException
|
|
10
|
-
from typing import Any, Optional, Union
|
|
10
|
+
from typing import Any, Optional, Union
|
|
11
11
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
|
12
12
|
from pydantic import ValidationError
|
|
13
13
|
|
|
@@ -16,6 +16,18 @@ from ..langchain.utils import propagate_the_input_mapping
|
|
|
16
16
|
logger = logging.getLogger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
def replace_escaped_newlines(data):
|
|
20
|
+
"""
|
|
21
|
+
Replace \\n with \n in all string values recursively.
|
|
22
|
+
Required for sanitization of state variables in code node
|
|
23
|
+
"""
|
|
24
|
+
if isinstance(data, dict):
|
|
25
|
+
return {key: replace_escaped_newlines(value) for key, value in data.items()}
|
|
26
|
+
elif isinstance(data, str):
|
|
27
|
+
return data.replace('\\n', '\n')
|
|
28
|
+
else:
|
|
29
|
+
return data
|
|
30
|
+
|
|
19
31
|
class FunctionTool(BaseTool):
|
|
20
32
|
name: str = 'FunctionalTool'
|
|
21
33
|
description: str = 'This is direct call node for tools'
|
|
@@ -30,29 +42,13 @@ class FunctionTool(BaseTool):
|
|
|
30
42
|
def _prepare_pyodide_input(self, state: Union[str, dict, ToolCall]) -> str:
|
|
31
43
|
"""Prepare input for PyodideSandboxTool by injecting state into the code block."""
|
|
32
44
|
# add state into the code block here since it might be changed during the execution of the code
|
|
33
|
-
state_copy = deepcopy(state)
|
|
45
|
+
state_copy = replace_escaped_newlines(deepcopy(state))
|
|
34
46
|
|
|
35
47
|
del state_copy['messages'] # remove messages to avoid issues with pickling without langchain-core
|
|
36
48
|
# inject state into the code block as alita_state variable
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
try:
|
|
41
|
-
import os
|
|
42
|
-
from pathlib import Path
|
|
43
|
-
|
|
44
|
-
# Get the directory of the current file and construct the path to sandbox_client.py
|
|
45
|
-
current_dir = Path(__file__).parent
|
|
46
|
-
sandbox_client_path = current_dir.parent / 'clients' / 'sandbox_client.py'
|
|
47
|
-
|
|
48
|
-
with open(sandbox_client_path, 'r') as f:
|
|
49
|
-
sandbox_client_code = f.read()
|
|
50
|
-
pyodide_predata += f"\n{sandbox_client_code}\n"
|
|
51
|
-
pyodide_predata += (f"alita_client = SandboxClient(base_url='{self.alita_client.base_url}',"
|
|
52
|
-
f"project_id={self.alita_client.project_id},"
|
|
53
|
-
f"auth_token='{self.alita_client.auth_token}')")
|
|
54
|
-
except FileNotFoundError:
|
|
55
|
-
logger.error(f"sandbox_client.py not found at {sandbox_client_path}. Ensure the file exists.")
|
|
49
|
+
state_json = json.dumps(state_copy, ensure_ascii=False)
|
|
50
|
+
pyodide_predata = f'#state dict\nimport json\nalita_state = json.loads({json.dumps(state_json)})\n'
|
|
51
|
+
|
|
56
52
|
return pyodide_predata
|
|
57
53
|
|
|
58
54
|
def _handle_pyodide_output(self, tool_result: Any) -> dict:
|
|
@@ -132,14 +128,22 @@ class FunctionTool(BaseTool):
|
|
|
132
128
|
if not self.output_variables:
|
|
133
129
|
return {"messages": [{"role": "assistant", "content": dumps(tool_result)}]}
|
|
134
130
|
else:
|
|
135
|
-
if self.output_variables
|
|
136
|
-
|
|
131
|
+
if "messages" in self.output_variables:
|
|
132
|
+
messages_dict = {
|
|
137
133
|
"messages": [{
|
|
138
134
|
"role": "assistant",
|
|
139
|
-
"content": dumps(tool_result)
|
|
140
|
-
|
|
135
|
+
"content": dumps(tool_result)
|
|
136
|
+
if not isinstance(tool_result, ToolException) and not isinstance(tool_result, str)
|
|
137
|
+
else str(tool_result)
|
|
141
138
|
}]
|
|
142
139
|
}
|
|
140
|
+
for var in self.output_variables:
|
|
141
|
+
if var != "messages":
|
|
142
|
+
if isinstance(tool_result, dict) and var in tool_result:
|
|
143
|
+
messages_dict[var] = tool_result[var]
|
|
144
|
+
else:
|
|
145
|
+
messages_dict[var] = tool_result
|
|
146
|
+
return messages_dict
|
|
143
147
|
else:
|
|
144
148
|
return { self.output_variables[0]: tool_result }
|
|
145
149
|
except ValidationError:
|
alita_sdk/runtime/tools/graph.py
CHANGED
|
@@ -47,8 +47,8 @@ def formulate_query(kwargs):
|
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
class GraphTool(BaseTool):
|
|
50
|
-
name: str
|
|
51
|
-
description: str
|
|
50
|
+
name: str = 'GraphTool'
|
|
51
|
+
description: str = 'Graph tool for tools'
|
|
52
52
|
graph: CompiledStateGraph
|
|
53
53
|
args_schema: Type[BaseModel] = graphToolSchema
|
|
54
54
|
return_type: str = "str"
|
|
@@ -65,10 +65,16 @@ class GraphTool(BaseTool):
|
|
|
65
65
|
all_kwargs = {**kwargs, **extras, **schema_values}
|
|
66
66
|
if config is None:
|
|
67
67
|
config = {}
|
|
68
|
-
|
|
68
|
+
# Pass the config to the _run empty or the one passed from the parent executor.
|
|
69
|
+
return self._run(config, **all_kwargs)
|
|
69
70
|
|
|
70
71
|
def _run(self, *args, **kwargs):
|
|
71
|
-
|
|
72
|
+
config = None
|
|
73
|
+
# From invoke method we are passing only 1 arg so it is safe to do this condition and config assignment.
|
|
74
|
+
# Default to None is safe because it will be checked also on the langchain side.
|
|
75
|
+
if args:
|
|
76
|
+
config = args[0]
|
|
77
|
+
response = self.graph.invoke(formulate_query(kwargs), config=config)
|
|
72
78
|
if self.return_type == "str":
|
|
73
79
|
return response["output"]
|
|
74
80
|
else:
|
|
@@ -2,16 +2,59 @@
|
|
|
2
2
|
Image generation tool for Alita SDK.
|
|
3
3
|
"""
|
|
4
4
|
import logging
|
|
5
|
-
from typing import Optional, Type, Any
|
|
6
|
-
from langchain_core.tools import BaseTool
|
|
7
|
-
from pydantic import BaseModel, Field
|
|
5
|
+
from typing import Optional, Type, Any, List, Literal
|
|
6
|
+
from langchain_core.tools import BaseTool, BaseToolkit
|
|
7
|
+
from pydantic import BaseModel, Field, create_model, ConfigDict
|
|
8
8
|
|
|
9
9
|
logger = logging.getLogger(__name__)
|
|
10
10
|
|
|
11
|
+
name = "image_generation"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_tools(tools_list: list, alita_client=None, llm=None,
|
|
15
|
+
memory_store=None):
|
|
16
|
+
"""
|
|
17
|
+
Get image generation tools for the provided tool configurations.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
tools_list: List of tool configurations
|
|
21
|
+
alita_client: Alita client instance (required for image generation)
|
|
22
|
+
llm: LLM client instance (unused for image generation)
|
|
23
|
+
memory_store: Optional memory store instance (unused)
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
List of image generation tools
|
|
27
|
+
"""
|
|
28
|
+
all_tools = []
|
|
29
|
+
|
|
30
|
+
for tool in tools_list:
|
|
31
|
+
if (tool.get('type') == 'image_generation' or
|
|
32
|
+
tool.get('toolkit_name') == 'image_generation'):
|
|
33
|
+
try:
|
|
34
|
+
if not alita_client:
|
|
35
|
+
logger.error("Alita client is required for image "
|
|
36
|
+
"generation tools")
|
|
37
|
+
continue
|
|
38
|
+
|
|
39
|
+
toolkit_instance = ImageGenerationToolkit.get_toolkit(
|
|
40
|
+
client=alita_client,
|
|
41
|
+
toolkit_name=tool.get('toolkit_name', '')
|
|
42
|
+
)
|
|
43
|
+
all_tools.extend(toolkit_instance.get_tools())
|
|
44
|
+
except Exception as e:
|
|
45
|
+
logger.error(f"Error in image generation toolkit "
|
|
46
|
+
f"get_tools: {e}")
|
|
47
|
+
logger.error(f"Tool config: {tool}")
|
|
48
|
+
raise
|
|
49
|
+
|
|
50
|
+
return all_tools
|
|
51
|
+
|
|
11
52
|
|
|
12
53
|
class ImageGenerationInput(BaseModel):
|
|
13
54
|
"""Input schema for image generation tool."""
|
|
14
|
-
prompt: str = Field(
|
|
55
|
+
prompt: str = Field(
|
|
56
|
+
description="Text prompt describing the image to generate"
|
|
57
|
+
)
|
|
15
58
|
n: int = Field(
|
|
16
59
|
default=1, description="Number of images to generate (1-10)",
|
|
17
60
|
ge=1, le=10
|
|
@@ -22,7 +65,7 @@ class ImageGenerationInput(BaseModel):
|
|
|
22
65
|
)
|
|
23
66
|
quality: str = Field(
|
|
24
67
|
default="auto",
|
|
25
|
-
description="Quality of the generated image ('low', 'medium', 'high'
|
|
68
|
+
description="Quality of the generated image ('low', 'medium', 'high')"
|
|
26
69
|
)
|
|
27
70
|
style: Optional[str] = Field(
|
|
28
71
|
default=None, description="Style of the generated image (optional)"
|
|
@@ -69,7 +112,8 @@ class ImageGenerationTool(BaseTool):
|
|
|
69
112
|
else:
|
|
70
113
|
content_chunks.append({
|
|
71
114
|
"type": "text",
|
|
72
|
-
"text": f"Generated {len(images)} images for
|
|
115
|
+
"text": f"Generated {len(images)} images for "
|
|
116
|
+
f"prompt: '{prompt}'"
|
|
73
117
|
})
|
|
74
118
|
|
|
75
119
|
# Add image content for each generated image
|
|
@@ -85,7 +129,8 @@ class ImageGenerationTool(BaseTool):
|
|
|
85
129
|
content_chunks.append({
|
|
86
130
|
"type": "image_url",
|
|
87
131
|
"image_url": {
|
|
88
|
-
"url": f"data:image/png;base64,
|
|
132
|
+
"url": f"data:image/png;base64,"
|
|
133
|
+
f"{image_data['b64_json']}"
|
|
89
134
|
}
|
|
90
135
|
})
|
|
91
136
|
|
|
@@ -94,7 +139,8 @@ class ImageGenerationTool(BaseTool):
|
|
|
94
139
|
# Fallback to text response if no images in result
|
|
95
140
|
return [{
|
|
96
141
|
"type": "text",
|
|
97
|
-
"text": f"Image generation completed but no images
|
|
142
|
+
"text": f"Image generation completed but no images "
|
|
143
|
+
f"returned: {result}"
|
|
98
144
|
}]
|
|
99
145
|
|
|
100
146
|
except Exception as e:
|
|
@@ -114,3 +160,53 @@ class ImageGenerationTool(BaseTool):
|
|
|
114
160
|
def create_image_generation_tool(client):
|
|
115
161
|
"""Create an image generation tool with the provided Alita client."""
|
|
116
162
|
return ImageGenerationTool(client=client)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ImageGenerationToolkit(BaseToolkit):
|
|
166
|
+
"""Toolkit for image generation tools."""
|
|
167
|
+
tools: List[BaseTool] = []
|
|
168
|
+
|
|
169
|
+
@staticmethod
|
|
170
|
+
def toolkit_config_schema() -> BaseModel:
|
|
171
|
+
"""Get the configuration schema for the image generation toolkit."""
|
|
172
|
+
# Create sample tool to get schema
|
|
173
|
+
sample_tool = ImageGenerationTool(client=None)
|
|
174
|
+
selected_tools = {sample_tool.name: sample_tool.args_schema.schema()}
|
|
175
|
+
|
|
176
|
+
return create_model(
|
|
177
|
+
'image_generation',
|
|
178
|
+
selected_tools=(
|
|
179
|
+
List[Literal[tuple(selected_tools)]],
|
|
180
|
+
Field(
|
|
181
|
+
default=[],
|
|
182
|
+
json_schema_extra={'args_schemas': selected_tools}
|
|
183
|
+
)
|
|
184
|
+
),
|
|
185
|
+
__config__=ConfigDict(json_schema_extra={
|
|
186
|
+
'metadata': {
|
|
187
|
+
"label": "Image Generation",
|
|
188
|
+
"icon_url": "image_generation.svg",
|
|
189
|
+
"hidden": True,
|
|
190
|
+
"categories": ["internal_tool"],
|
|
191
|
+
"extra_categories": ["image generation"],
|
|
192
|
+
}
|
|
193
|
+
})
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
@classmethod
|
|
197
|
+
def get_toolkit(cls, client=None, **kwargs):
|
|
198
|
+
"""
|
|
199
|
+
Get toolkit with image generation tools.
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
client: Alita client instance (required)
|
|
203
|
+
**kwargs: Additional arguments
|
|
204
|
+
"""
|
|
205
|
+
if not client:
|
|
206
|
+
raise ValueError("Alita client is required for image generation")
|
|
207
|
+
|
|
208
|
+
tools = [ImageGenerationTool(client=client)]
|
|
209
|
+
return cls(tools=tools)
|
|
210
|
+
|
|
211
|
+
def get_tools(self):
|
|
212
|
+
return self.tools
|