alita-sdk: alita_sdk-0.3.562-py3-none-any.whl → alita_sdk-0.3.584-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- alita_sdk/cli/agents.py +358 -165
- alita_sdk/configurations/openapi.py +227 -15
- alita_sdk/runtime/langchain/langraph_agent.py +93 -20
- alita_sdk/runtime/langchain/utils.py +30 -14
- alita_sdk/runtime/toolkits/artifact.py +2 -1
- alita_sdk/runtime/toolkits/mcp.py +4 -2
- alita_sdk/runtime/toolkits/skill_router.py +1 -1
- alita_sdk/runtime/toolkits/vectorstore.py +1 -1
- alita_sdk/runtime/tools/data_analysis.py +1 -1
- alita_sdk/runtime/tools/llm.py +30 -11
- alita_sdk/runtime/utils/constants.py +5 -1
- alita_sdk/tools/ado/repos/__init__.py +2 -1
- alita_sdk/tools/ado/test_plan/__init__.py +2 -1
- alita_sdk/tools/ado/wiki/__init__.py +2 -1
- alita_sdk/tools/ado/work_item/__init__.py +2 -1
- alita_sdk/tools/advanced_jira_mining/__init__.py +2 -1
- alita_sdk/tools/aws/delta_lake/__init__.py +2 -1
- alita_sdk/tools/azure_ai/search/__init__.py +2 -1
- alita_sdk/tools/bitbucket/__init__.py +2 -1
- alita_sdk/tools/browser/__init__.py +1 -1
- alita_sdk/tools/carrier/__init__.py +1 -1
- alita_sdk/tools/cloud/aws/__init__.py +2 -1
- alita_sdk/tools/cloud/azure/__init__.py +2 -1
- alita_sdk/tools/cloud/gcp/__init__.py +2 -1
- alita_sdk/tools/cloud/k8s/__init__.py +2 -1
- alita_sdk/tools/code/linter/__init__.py +2 -1
- alita_sdk/tools/code/sonar/__init__.py +2 -1
- alita_sdk/tools/confluence/__init__.py +2 -1
- alita_sdk/tools/custom_open_api/__init__.py +2 -1
- alita_sdk/tools/elastic/__init__.py +2 -1
- alita_sdk/tools/figma/__init__.py +51 -5
- alita_sdk/tools/figma/api_wrapper.py +1157 -123
- alita_sdk/tools/figma/figma_client.py +73 -0
- alita_sdk/tools/figma/toon_tools.py +2748 -0
- alita_sdk/tools/github/__init__.py +2 -1
- alita_sdk/tools/gitlab/__init__.py +2 -1
- alita_sdk/tools/gitlab/api_wrapper.py +32 -0
- alita_sdk/tools/gitlab_org/__init__.py +2 -1
- alita_sdk/tools/google/bigquery/__init__.py +2 -1
- alita_sdk/tools/google_places/__init__.py +2 -1
- alita_sdk/tools/jira/__init__.py +2 -1
- alita_sdk/tools/keycloak/__init__.py +2 -1
- alita_sdk/tools/localgit/__init__.py +2 -1
- alita_sdk/tools/memory/__init__.py +1 -1
- alita_sdk/tools/ocr/__init__.py +2 -1
- alita_sdk/tools/openapi/__init__.py +227 -15
- alita_sdk/tools/openapi/api_wrapper.py +1276 -802
- alita_sdk/tools/pandas/__init__.py +3 -2
- alita_sdk/tools/postman/__init__.py +2 -1
- alita_sdk/tools/pptx/__init__.py +2 -1
- alita_sdk/tools/qtest/__init__.py +2 -1
- alita_sdk/tools/rally/__init__.py +2 -1
- alita_sdk/tools/report_portal/__init__.py +2 -1
- alita_sdk/tools/salesforce/__init__.py +2 -1
- alita_sdk/tools/servicenow/__init__.py +2 -1
- alita_sdk/tools/sharepoint/__init__.py +2 -1
- alita_sdk/tools/slack/__init__.py +3 -2
- alita_sdk/tools/sql/__init__.py +2 -1
- alita_sdk/tools/testio/__init__.py +2 -1
- alita_sdk/tools/testrail/__init__.py +2 -1
- alita_sdk/tools/utils/content_parser.py +68 -2
- alita_sdk/tools/xray/__init__.py +2 -1
- alita_sdk/tools/yagmail/__init__.py +2 -1
- alita_sdk/tools/zephyr/__init__.py +2 -1
- alita_sdk/tools/zephyr_enterprise/__init__.py +2 -1
- alita_sdk/tools/zephyr_essential/__init__.py +2 -1
- alita_sdk/tools/zephyr_scale/__init__.py +2 -1
- alita_sdk/tools/zephyr_squad/__init__.py +2 -1
- {alita_sdk-0.3.562.dist-info → alita_sdk-0.3.584.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.562.dist-info → alita_sdk-0.3.584.dist-info}/RECORD +74 -72
- {alita_sdk-0.3.562.dist-info → alita_sdk-0.3.584.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.562.dist-info → alita_sdk-0.3.584.dist-info}/entry_points.txt +0 -0
- {alita_sdk-0.3.562.dist-info → alita_sdk-0.3.584.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.562.dist-info → alita_sdk-0.3.584.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,23 @@
|
|
|
1
1
|
from typing import Any, Literal, Optional
|
|
2
2
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
|
3
3
|
|
|
4
|
+
import base64
|
|
5
|
+
import requests
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
class OpenApiConfiguration(BaseModel):
|
|
9
|
+
"""
|
|
10
|
+
OpenAPI configuration for authentication.
|
|
11
|
+
|
|
12
|
+
Supports three authentication modes:
|
|
13
|
+
- Anonymous: No authentication (all fields empty)
|
|
14
|
+
- API Key: Static key sent via header (Bearer, Basic, or Custom)
|
|
15
|
+
- OAuth2 Client Credentials: Machine-to-machine authentication flow
|
|
16
|
+
|
|
17
|
+
Note: Only OAuth2 Client Credentials flow is supported. Authorization Code flow
|
|
18
|
+
is not supported as it requires user interaction and pre-registered redirect URLs.
|
|
19
|
+
"""
|
|
20
|
+
|
|
6
21
|
model_config = ConfigDict(
|
|
7
22
|
extra='allow',
|
|
8
23
|
json_schema_extra={
|
|
@@ -25,7 +40,6 @@ class OpenApiConfiguration(BaseModel):
|
|
|
25
40
|
"fields": [
|
|
26
41
|
"client_id",
|
|
27
42
|
"client_secret",
|
|
28
|
-
"auth_url",
|
|
29
43
|
"token_url",
|
|
30
44
|
"scope",
|
|
31
45
|
"method",
|
|
@@ -39,6 +53,10 @@ class OpenApiConfiguration(BaseModel):
|
|
|
39
53
|
}
|
|
40
54
|
)
|
|
41
55
|
|
|
56
|
+
# =========================================================================
|
|
57
|
+
# API Key Authentication Fields
|
|
58
|
+
# =========================================================================
|
|
59
|
+
|
|
42
60
|
api_key: Optional[SecretStr] = Field(
|
|
43
61
|
default=None,
|
|
44
62
|
description=(
|
|
@@ -60,15 +78,44 @@ class OpenApiConfiguration(BaseModel):
|
|
|
60
78
|
json_schema_extra={'visible_when': {'field': 'auth_type', 'value': 'custom'}},
|
|
61
79
|
)
|
|
62
80
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
81
|
+
# =========================================================================
|
|
82
|
+
# OAuth2 Client Credentials Flow Fields
|
|
83
|
+
# =========================================================================
|
|
84
|
+
|
|
85
|
+
client_id: Optional[str] = Field(
|
|
86
|
+
default=None,
|
|
87
|
+
description='OAuth2 client ID (also known as Application ID or App ID)'
|
|
88
|
+
)
|
|
89
|
+
client_secret: Optional[SecretStr] = Field(
|
|
90
|
+
default=None,
|
|
91
|
+
description='OAuth2 client secret (stored securely)'
|
|
92
|
+
)
|
|
93
|
+
token_url: Optional[str] = Field(
|
|
94
|
+
default=None,
|
|
95
|
+
description=(
|
|
96
|
+
'OAuth2 token endpoint URL for obtaining access tokens. '
|
|
97
|
+
'Examples: '
|
|
98
|
+
'Azure AD: https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token, '
|
|
99
|
+
'Google: https://oauth2.googleapis.com/token, '
|
|
100
|
+
'Auth0: https://{domain}/oauth/token, '
|
|
101
|
+
'Spotify: https://accounts.spotify.com/api/token'
|
|
102
|
+
)
|
|
103
|
+
)
|
|
104
|
+
scope: Optional[str] = Field(
|
|
105
|
+
default=None,
|
|
106
|
+
description=(
|
|
107
|
+
'OAuth2 scope(s), space-separated if multiple (per OAuth2 RFC 6749). '
|
|
108
|
+
'Examples: "user-read-private user-read-email" (Spotify), '
|
|
109
|
+
'"api://app-id/.default" (Azure), '
|
|
110
|
+
'"https://www.googleapis.com/auth/cloud-platform" (Google)'
|
|
111
|
+
)
|
|
112
|
+
)
|
|
68
113
|
method: Optional[Literal['default', 'Basic']] = Field(
|
|
69
114
|
default='default',
|
|
70
115
|
description=(
|
|
71
|
-
"Token exchange method
|
|
116
|
+
"Token exchange method for client credentials flow. "
|
|
117
|
+
"'default': Sends client_id and client_secret in POST body (Azure AD, Auth0, most providers). "
|
|
118
|
+
"'Basic': Sends credentials via HTTP Basic auth header - required by Spotify, some AWS services, and certain OAuth providers."
|
|
72
119
|
),
|
|
73
120
|
)
|
|
74
121
|
|
|
@@ -78,9 +125,9 @@ class OpenApiConfiguration(BaseModel):
|
|
|
78
125
|
if not isinstance(values, dict):
|
|
79
126
|
return values
|
|
80
127
|
|
|
81
|
-
# OAuth: if any OAuth field is provided, require the
|
|
128
|
+
# OAuth: if any OAuth field is provided, require the essential ones
|
|
82
129
|
has_any_oauth = any(
|
|
83
|
-
(values.get('client_id'), values.get('client_secret'), values.get('
|
|
130
|
+
(values.get('client_id'), values.get('client_secret'), values.get('token_url'))
|
|
84
131
|
)
|
|
85
132
|
if has_any_oauth:
|
|
86
133
|
missing = []
|
|
@@ -93,7 +140,7 @@ class OpenApiConfiguration(BaseModel):
|
|
|
93
140
|
if missing:
|
|
94
141
|
raise ValueError(f"OAuth is misconfigured; missing: {', '.join(missing)}")
|
|
95
142
|
|
|
96
|
-
# API key: if auth_type is custom, custom_header_name must be present
|
|
143
|
+
# API key: if auth_type is custom, custom_header_name must be present
|
|
97
144
|
auth_type = values.get('auth_type')
|
|
98
145
|
if isinstance(auth_type, str) and auth_type.strip().lower() == 'custom' and values.get('api_key'):
|
|
99
146
|
if not values.get('custom_header_name'):
|
|
@@ -103,9 +150,174 @@ class OpenApiConfiguration(BaseModel):
|
|
|
103
150
|
|
|
104
151
|
@staticmethod
|
|
105
152
|
def check_connection(settings: dict) -> str | None:
|
|
106
|
-
"""Best-effort validation for OpenAPI credentials.
|
|
107
|
-
|
|
108
|
-
This model is intended to store reusable credentials only.
|
|
109
|
-
Spec/base_url validation happens at toolkit configuration level.
|
|
110
153
|
"""
|
|
111
|
-
|
|
154
|
+
Validate the OpenAPI configuration by testing connectivity where possible.
|
|
155
|
+
|
|
156
|
+
Validation behavior by authentication type:
|
|
157
|
+
|
|
158
|
+
1. ANONYMOUS (no auth fields configured):
|
|
159
|
+
- Cannot validate without making actual API calls
|
|
160
|
+
- Returns None (success) - validation skipped
|
|
161
|
+
|
|
162
|
+
2. API KEY (api_key field configured):
|
|
163
|
+
- Cannot validate without knowing which endpoint to call
|
|
164
|
+
- The OpenAPI spec is not available at configuration time
|
|
165
|
+
- Returns None (success) - validation skipped
|
|
166
|
+
|
|
167
|
+
3. OAUTH2 CLIENT CREDENTIALS (client_id, client_secret, token_url configured):
|
|
168
|
+
- CAN validate by attempting token exchange with the OAuth provider
|
|
169
|
+
- Makes a real HTTP request to token_url
|
|
170
|
+
- Returns None on success, error message on failure
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
settings: Dictionary containing OpenAPI configuration fields
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
None: Configuration is valid (or cannot be validated for this auth type)
|
|
177
|
+
str: Error message describing the validation failure
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
# =====================================================================
|
|
181
|
+
# Determine authentication type from configured fields
|
|
182
|
+
# =====================================================================
|
|
183
|
+
|
|
184
|
+
client_id = settings.get('client_id')
|
|
185
|
+
client_secret = settings.get('client_secret')
|
|
186
|
+
token_url = settings.get('token_url')
|
|
187
|
+
|
|
188
|
+
has_oauth_fields = client_id or client_secret or token_url
|
|
189
|
+
|
|
190
|
+
# =====================================================================
|
|
191
|
+
# ANONYMOUS or API KEY: Cannot validate, return success
|
|
192
|
+
# =====================================================================
|
|
193
|
+
|
|
194
|
+
if not has_oauth_fields:
|
|
195
|
+
# No OAuth fields configured - this is either:
|
|
196
|
+
# - Anonymous authentication (no auth at all)
|
|
197
|
+
# - API Key authentication (api_key field may be set)
|
|
198
|
+
#
|
|
199
|
+
# Neither can be validated without making actual API calls to the
|
|
200
|
+
# target service, and we don't have the OpenAPI spec available here.
|
|
201
|
+
return None
|
|
202
|
+
|
|
203
|
+
# =====================================================================
|
|
204
|
+
# OAUTH2: Validate by attempting token exchange
|
|
205
|
+
# =====================================================================
|
|
206
|
+
|
|
207
|
+
# Check for required OAuth fields
|
|
208
|
+
if not client_id:
|
|
209
|
+
return "OAuth client_id is required when using OAuth authentication"
|
|
210
|
+
if not client_secret:
|
|
211
|
+
return "OAuth client_secret is required when using OAuth authentication"
|
|
212
|
+
if not token_url:
|
|
213
|
+
return "OAuth token_url is required when using OAuth authentication"
|
|
214
|
+
|
|
215
|
+
# Extract secret value if it's a SecretStr
|
|
216
|
+
if hasattr(client_secret, 'get_secret_value'):
|
|
217
|
+
client_secret = client_secret.get_secret_value()
|
|
218
|
+
|
|
219
|
+
if not client_secret or not str(client_secret).strip():
|
|
220
|
+
return "OAuth client_secret cannot be empty"
|
|
221
|
+
|
|
222
|
+
# Validate token_url format
|
|
223
|
+
token_url = token_url.strip()
|
|
224
|
+
if not token_url.startswith(('http://', 'https://')):
|
|
225
|
+
return "OAuth token_url must start with http:// or https://"
|
|
226
|
+
|
|
227
|
+
# Get optional OAuth settings
|
|
228
|
+
scope = settings.get('scope')
|
|
229
|
+
method = settings.get('method', 'default') or 'default'
|
|
230
|
+
|
|
231
|
+
# ---------------------------------------------------------------------
|
|
232
|
+
# Attempt OAuth2 Client Credentials token exchange
|
|
233
|
+
# ---------------------------------------------------------------------
|
|
234
|
+
|
|
235
|
+
try:
|
|
236
|
+
headers = {
|
|
237
|
+
'Content-Type': 'application/x-www-form-urlencoded',
|
|
238
|
+
'Accept': 'application/json',
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
data = {
|
|
242
|
+
'grant_type': 'client_credentials',
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
# Apply credentials based on method
|
|
246
|
+
if method == 'Basic':
|
|
247
|
+
# Basic method: credentials in Authorization header (Spotify, some AWS)
|
|
248
|
+
credentials = f"{client_id}:{client_secret}"
|
|
249
|
+
encoded = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
|
|
250
|
+
headers['Authorization'] = f'Basic {encoded}'
|
|
251
|
+
else:
|
|
252
|
+
# Default method: credentials in POST body (Azure AD, Auth0, most providers)
|
|
253
|
+
data['client_id'] = client_id
|
|
254
|
+
data['client_secret'] = str(client_secret)
|
|
255
|
+
|
|
256
|
+
if scope:
|
|
257
|
+
data['scope'] = scope
|
|
258
|
+
|
|
259
|
+
response = requests.post(
|
|
260
|
+
token_url,
|
|
261
|
+
headers=headers,
|
|
262
|
+
data=data,
|
|
263
|
+
timeout=30,
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
# ---------------------------------------------------------------------
|
|
267
|
+
# Handle response
|
|
268
|
+
# ---------------------------------------------------------------------
|
|
269
|
+
|
|
270
|
+
if response.status_code == 200:
|
|
271
|
+
try:
|
|
272
|
+
token_data = response.json()
|
|
273
|
+
if 'access_token' in token_data:
|
|
274
|
+
return None # Success - token obtained
|
|
275
|
+
return "OAuth response did not contain 'access_token'"
|
|
276
|
+
except Exception:
|
|
277
|
+
return "Failed to parse OAuth token response"
|
|
278
|
+
|
|
279
|
+
# Handle common error status codes with helpful messages
|
|
280
|
+
if response.status_code == 400:
|
|
281
|
+
try:
|
|
282
|
+
error_data = response.json()
|
|
283
|
+
error = error_data.get('error', 'bad_request')
|
|
284
|
+
error_desc = error_data.get('error_description', '')
|
|
285
|
+
if error_desc:
|
|
286
|
+
return f"OAuth error: {error} - {error_desc}"
|
|
287
|
+
return f"OAuth error: {error}"
|
|
288
|
+
except Exception:
|
|
289
|
+
return "OAuth request failed: bad request (400)"
|
|
290
|
+
|
|
291
|
+
if response.status_code == 401:
|
|
292
|
+
return "OAuth authentication failed: invalid client_id or client_secret"
|
|
293
|
+
|
|
294
|
+
if response.status_code == 403:
|
|
295
|
+
return "OAuth access forbidden: client may lack required permissions"
|
|
296
|
+
|
|
297
|
+
if response.status_code == 404:
|
|
298
|
+
return f"OAuth token endpoint not found: {token_url}"
|
|
299
|
+
|
|
300
|
+
return f"OAuth token request failed with status {response.status_code}"
|
|
301
|
+
|
|
302
|
+
except requests.exceptions.SSLError as e:
|
|
303
|
+
error_str = str(e).lower()
|
|
304
|
+
if 'hostname mismatch' in error_str:
|
|
305
|
+
return "OAuth token_url hostname does not match SSL certificate - verify the URL is correct"
|
|
306
|
+
if 'certificate verify failed' in error_str:
|
|
307
|
+
return "SSL certificate verification failed for OAuth endpoint - the server may have an invalid or self-signed certificate"
|
|
308
|
+
if 'certificate has expired' in error_str:
|
|
309
|
+
return "SSL certificate has expired for OAuth endpoint"
|
|
310
|
+
return "SSL error connecting to OAuth endpoint - verify the token_url is correct"
|
|
311
|
+
except requests.exceptions.ConnectionError as e:
|
|
312
|
+
error_str = str(e).lower()
|
|
313
|
+
if 'name or service not known' in error_str or 'nodename nor servname provided' in error_str:
|
|
314
|
+
return "OAuth token_url hostname could not be resolved - verify the URL is correct"
|
|
315
|
+
if 'connection refused' in error_str:
|
|
316
|
+
return "Connection refused by OAuth endpoint - verify the token_url and port are correct"
|
|
317
|
+
return "Cannot connect to OAuth token endpoint - verify the token_url is correct"
|
|
318
|
+
except requests.exceptions.Timeout:
|
|
319
|
+
return "OAuth token request timed out - the endpoint may be unreachable"
|
|
320
|
+
except requests.exceptions.RequestException:
|
|
321
|
+
return "OAuth request failed - verify the token_url is correct and accessible"
|
|
322
|
+
except Exception:
|
|
323
|
+
return "Unexpected error during OAuth configuration validation"
|
|
@@ -23,6 +23,7 @@ from langgraph.store.base import BaseStore
|
|
|
23
23
|
from .constants import PRINTER_NODE_RS, PRINTER, PRINTER_COMPLETED_STATE
|
|
24
24
|
from .mixedAgentRenderes import convert_message_to_json
|
|
25
25
|
from .utils import create_state, propagate_the_input_mapping, safe_format
|
|
26
|
+
from ..utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META
|
|
26
27
|
from ..tools.function import FunctionTool
|
|
27
28
|
from ..tools.indexer_tool import IndexerNode
|
|
28
29
|
from ..tools.llm import LLMNode
|
|
@@ -188,7 +189,7 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
188
189
|
decision_input = state.get('messages', [])[:]
|
|
189
190
|
else:
|
|
190
191
|
if len(additional_info) == 0:
|
|
191
|
-
additional_info = """###
|
|
192
|
+
additional_info = """### Additional info: """
|
|
192
193
|
additional_info += "{field}: {value}\n".format(field=field, value=state.get(field, ""))
|
|
193
194
|
decision_input.append(HumanMessage(
|
|
194
195
|
self.prompt.format(steps=self.steps, description=safe_format(self.description, state), additional_info=additional_info)))
|
|
@@ -447,6 +448,50 @@ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_befo
|
|
|
447
448
|
return compiled
|
|
448
449
|
|
|
449
450
|
|
|
451
|
+
def find_tool_by_name_or_metadata(tools: list, tool_name: str, toolkit_name: Optional[str] = None) -> Optional[BaseTool]:
|
|
452
|
+
"""
|
|
453
|
+
Find a tool by name or by matching metadata (toolkit_name + tool_name).
|
|
454
|
+
|
|
455
|
+
For toolkit nodes with toolkit_name specified, this function checks:
|
|
456
|
+
1. Metadata match first (toolkit_name + tool_name) - PRIORITY when toolkit_name is provided
|
|
457
|
+
2. Direct tool name match (backward compatibility fallback)
|
|
458
|
+
|
|
459
|
+
For toolkit nodes without toolkit_name, or other node types:
|
|
460
|
+
1. Direct tool name match
|
|
461
|
+
|
|
462
|
+
Args:
|
|
463
|
+
tools: List of available tools
|
|
464
|
+
tool_name: The tool name to search for
|
|
465
|
+
toolkit_name: Optional toolkit name for metadata matching
|
|
466
|
+
|
|
467
|
+
Returns:
|
|
468
|
+
The matching tool or None if not found
|
|
469
|
+
"""
|
|
470
|
+
# When toolkit_name is specified, prioritize metadata matching
|
|
471
|
+
if toolkit_name:
|
|
472
|
+
for tool in tools:
|
|
473
|
+
# Check metadata match first
|
|
474
|
+
if hasattr(tool, 'metadata') and tool.metadata:
|
|
475
|
+
metadata_toolkit_name = tool.metadata.get(TOOLKIT_NAME_META)
|
|
476
|
+
metadata_tool_name = tool.metadata.get(TOOL_NAME_META)
|
|
477
|
+
|
|
478
|
+
# Match if both toolkit_name and tool_name in metadata match
|
|
479
|
+
if metadata_toolkit_name == toolkit_name and metadata_tool_name == tool_name:
|
|
480
|
+
return tool
|
|
481
|
+
|
|
482
|
+
# Fallback to direct name match for backward compatibility
|
|
483
|
+
for tool in tools:
|
|
484
|
+
if tool.name == tool_name:
|
|
485
|
+
return tool
|
|
486
|
+
else:
|
|
487
|
+
# No toolkit_name specified, use direct name match only
|
|
488
|
+
for tool in tools:
|
|
489
|
+
if tool.name == tool_name:
|
|
490
|
+
return tool
|
|
491
|
+
|
|
492
|
+
return None
|
|
493
|
+
|
|
494
|
+
|
|
450
495
|
def create_graph(
|
|
451
496
|
client: Any,
|
|
452
497
|
yaml_schema: str,
|
|
@@ -482,19 +527,37 @@ def create_graph(
|
|
|
482
527
|
node_type = node.get('type', 'function')
|
|
483
528
|
node_id = clean_string(node['id'])
|
|
484
529
|
toolkit_name = node.get('toolkit_name')
|
|
485
|
-
tool_name = clean_string(node.get('tool',
|
|
530
|
+
tool_name = clean_string(node.get('tool', ''))
|
|
486
531
|
# Tool names are now clean (no prefix needed)
|
|
487
532
|
logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
|
|
488
533
|
if node_type in ['function', 'toolkit', 'mcp', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
489
|
-
if node_type
|
|
490
|
-
#
|
|
491
|
-
raise ToolException(f"
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
534
|
+
if node_type in ['mcp', 'toolkit', 'agent'] and not tool_name:
|
|
535
|
+
# tool is not specified
|
|
536
|
+
raise ToolException(f"Tool name is required for {node_type} node with id '{node_id}'")
|
|
537
|
+
|
|
538
|
+
# Unified validation and tool finding for toolkit, mcp, and agent node types
|
|
539
|
+
matching_tool = None
|
|
540
|
+
if node_type in ['toolkit', 'mcp', 'agent']:
|
|
541
|
+
# Use enhanced validation that checks both direct name and metadata
|
|
542
|
+
matching_tool = find_tool_by_name_or_metadata(tools, tool_name, toolkit_name)
|
|
543
|
+
if not matching_tool:
|
|
544
|
+
# tool is not found in the provided tools
|
|
545
|
+
error_msg = f"Node `{node_id}` with type `{node_type}` has tool '{tool_name}'"
|
|
546
|
+
if toolkit_name:
|
|
547
|
+
error_msg += f" (toolkit: '{toolkit_name}')"
|
|
548
|
+
error_msg += f" which is not found in the provided tools. Make sure it is connected properly. Available tools: {format_tools(tools)}"
|
|
549
|
+
raise ToolException(error_msg)
|
|
550
|
+
else:
|
|
551
|
+
# For other node types, find tool by direct name match
|
|
552
|
+
for tool in tools:
|
|
553
|
+
if tool.name == tool_name:
|
|
554
|
+
matching_tool = tool
|
|
555
|
+
break
|
|
556
|
+
|
|
557
|
+
if matching_tool:
|
|
495
558
|
if node_type in ['function', 'toolkit', 'mcp']:
|
|
496
559
|
lg_builder.add_node(node_id, FunctionTool(
|
|
497
|
-
tool=
|
|
560
|
+
tool=matching_tool, name=node_id, return_type='dict',
|
|
498
561
|
output_variables=node.get('output', []),
|
|
499
562
|
input_mapping=node.get('input_mapping',
|
|
500
563
|
{'messages': {'type': 'variable', 'value': 'messages'}}),
|
|
@@ -505,7 +568,7 @@ def create_graph(
|
|
|
505
568
|
{'messages': {'type': 'variable', 'value': 'messages'}})
|
|
506
569
|
output_vars = node.get('output', [])
|
|
507
570
|
lg_builder.add_node(node_id, FunctionTool(
|
|
508
|
-
client=client, tool=
|
|
571
|
+
client=client, tool=matching_tool,
|
|
509
572
|
name=node_id, return_type='str',
|
|
510
573
|
output_variables=output_vars + ['messages'] if 'messages' not in output_vars else output_vars,
|
|
511
574
|
input_variables=input_params,
|
|
@@ -513,15 +576,15 @@ def create_graph(
|
|
|
513
576
|
))
|
|
514
577
|
elif node_type == 'subgraph' or node_type == 'pipeline':
|
|
515
578
|
# assign parent memory/store
|
|
516
|
-
#
|
|
517
|
-
#
|
|
579
|
+
# matching_tool.checkpointer = memory
|
|
580
|
+
# matching_tool.store = store
|
|
518
581
|
# wrap with mappings
|
|
519
582
|
pipeline_name = node.get('tool', None)
|
|
520
583
|
if not pipeline_name:
|
|
521
584
|
raise ValueError(
|
|
522
585
|
"Subgraph must have a 'tool' node: add required tool to the subgraph node")
|
|
523
586
|
node_fn = SubgraphRunnable(
|
|
524
|
-
inner=
|
|
587
|
+
inner=matching_tool.graph,
|
|
525
588
|
name=pipeline_name,
|
|
526
589
|
input_mapping=node.get('input_mapping', {}),
|
|
527
590
|
output_mapping=node.get('output_mapping', {}),
|
|
@@ -530,7 +593,7 @@ def create_graph(
|
|
|
530
593
|
break # skip legacy handling
|
|
531
594
|
elif node_type == 'tool':
|
|
532
595
|
lg_builder.add_node(node_id, ToolNode(
|
|
533
|
-
client=client, tool=
|
|
596
|
+
client=client, tool=matching_tool,
|
|
534
597
|
name=node_id, return_type='dict',
|
|
535
598
|
output_variables=node.get('output', []),
|
|
536
599
|
input_variables=node.get('input', ['messages']),
|
|
@@ -539,7 +602,7 @@ def create_graph(
|
|
|
539
602
|
))
|
|
540
603
|
elif node_type == 'loop':
|
|
541
604
|
lg_builder.add_node(node_id, LoopNode(
|
|
542
|
-
client=client, tool=
|
|
605
|
+
client=client, tool=matching_tool,
|
|
543
606
|
name=node_id, return_type='dict',
|
|
544
607
|
output_variables=node.get('output', []),
|
|
545
608
|
input_variables=node.get('input', ['messages']),
|
|
@@ -557,7 +620,7 @@ def create_graph(
|
|
|
557
620
|
lg_builder.add_node(node_id, LoopToolNode(
|
|
558
621
|
client=client,
|
|
559
622
|
name=node_id, return_type='dict',
|
|
560
|
-
tool=
|
|
623
|
+
tool=matching_tool, loop_tool=t,
|
|
561
624
|
variables_mapping=node.get('variables_mapping', {}),
|
|
562
625
|
output_variables=node.get('output', []),
|
|
563
626
|
input_variables=node.get('input', ['messages']),
|
|
@@ -573,7 +636,7 @@ def create_graph(
|
|
|
573
636
|
indexer_tool = t
|
|
574
637
|
logger.info(f"Indexer tool: {indexer_tool}")
|
|
575
638
|
lg_builder.add_node(node_id, IndexerNode(
|
|
576
|
-
client=client, tool=
|
|
639
|
+
client=client, tool=matching_tool,
|
|
577
640
|
index_tool=indexer_tool,
|
|
578
641
|
input_mapping=node.get('input_mapping', {}),
|
|
579
642
|
name=node_id, return_type='dict',
|
|
@@ -582,7 +645,6 @@ def create_graph(
|
|
|
582
645
|
output_variables=node.get('output', []),
|
|
583
646
|
input_variables=node.get('input', ['messages']),
|
|
584
647
|
structured_output=node.get('structured_output', False)))
|
|
585
|
-
break
|
|
586
648
|
elif node_type == 'code':
|
|
587
649
|
from ..tools.sandbox import create_sandbox_tool
|
|
588
650
|
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True,
|
|
@@ -651,10 +713,13 @@ def create_graph(
|
|
|
651
713
|
))
|
|
652
714
|
elif node_type == 'decision':
|
|
653
715
|
logger.info(f'Adding decision: {node["nodes"]}')
|
|
716
|
+
# fallback to old-style decision node
|
|
717
|
+
decisional_inputs = node.get('decisional_inputs')
|
|
718
|
+
decisional_inputs = node.get('input', ['messages']) if not decisional_inputs else decisional_inputs
|
|
654
719
|
lg_builder.add_node(node_id, DecisionEdge(
|
|
655
720
|
client, node['nodes'],
|
|
656
721
|
node.get('description', ""),
|
|
657
|
-
decisional_inputs=
|
|
722
|
+
decisional_inputs=decisional_inputs,
|
|
658
723
|
default_output=node.get('default_output', 'END'),
|
|
659
724
|
is_node=True
|
|
660
725
|
))
|
|
@@ -750,7 +815,7 @@ def create_graph(
|
|
|
750
815
|
)
|
|
751
816
|
except ValueError as e:
|
|
752
817
|
raise ValueError(
|
|
753
|
-
f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n{lg_builder.nodes}\n\n**Schema
|
|
818
|
+
f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n*{'\n*'.join(lg_builder.nodes.keys())}\n\n**Schema Edges:**\n\n{lg_builder.edges}\n\n**Tools Available:**\n\n{format_tools(tools)}")
|
|
754
819
|
# If building a nested subgraph, return the raw CompiledStateGraph
|
|
755
820
|
if for_subgraph:
|
|
756
821
|
return graph
|
|
@@ -764,6 +829,14 @@ def create_graph(
|
|
|
764
829
|
)
|
|
765
830
|
return compiled.validate()
|
|
766
831
|
|
|
832
|
+
def format_tools(tools_list: list) -> str:
|
|
833
|
+
"""Format a list of tool names into a comma-separated string."""
|
|
834
|
+
try:
|
|
835
|
+
return ', '.join([tool.name for tool in tools_list])
|
|
836
|
+
except Exception as e:
|
|
837
|
+
logger.warning(f"Failed to format tools list: {e}")
|
|
838
|
+
return str(tools_list)
|
|
839
|
+
|
|
767
840
|
def set_defaults(d):
|
|
768
841
|
"""Set default values for dictionary entries based on their type."""
|
|
769
842
|
type_defaults = {
|
|
@@ -2,7 +2,7 @@ import builtins
|
|
|
2
2
|
import json
|
|
3
3
|
import logging
|
|
4
4
|
import re
|
|
5
|
-
from pydantic import create_model, Field,
|
|
5
|
+
from pydantic import create_model, Field, JsonValue
|
|
6
6
|
from typing import Tuple, TypedDict, Any, Optional, Annotated
|
|
7
7
|
from langchain_core.messages import AnyMessage
|
|
8
8
|
from langgraph.graph import add_messages
|
|
@@ -263,17 +263,33 @@ def create_pydantic_model(model_name: str, variables: dict[str, dict]):
|
|
|
263
263
|
return create_model(model_name, **fields)
|
|
264
264
|
|
|
265
265
|
def parse_pydantic_type(type_name: str):
|
|
266
|
-
""
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
266
|
+
t = (type_name or "any").strip().lower()
|
|
267
|
+
|
|
268
|
+
base = {
|
|
269
|
+
"str": str,
|
|
270
|
+
"int": int,
|
|
271
|
+
"float": float,
|
|
272
|
+
"bool": bool,
|
|
273
|
+
# "dict" means JSON object
|
|
274
|
+
"dict": dict[str, JsonValue],
|
|
275
|
+
# "list" means array of JSON values (or pick str if you want)
|
|
276
|
+
"list": list[JsonValue],
|
|
277
|
+
# IMPORTANT: don't return bare Any -> it produces {} schema
|
|
278
|
+
"any": JsonValue,
|
|
278
279
|
}
|
|
279
|
-
|
|
280
|
+
if t in base:
|
|
281
|
+
return base[t]
|
|
282
|
+
|
|
283
|
+
m = re.fullmatch(r"list\[(.+)\]", t)
|
|
284
|
+
if m:
|
|
285
|
+
return list[parse_pydantic_type(m.group(1))]
|
|
286
|
+
|
|
287
|
+
m = re.fullmatch(r"dict\[(.+?),(.+)\]", t)
|
|
288
|
+
if m:
|
|
289
|
+
k = parse_pydantic_type(m.group(1))
|
|
290
|
+
v = parse_pydantic_type(m.group(2))
|
|
291
|
+
# restrict keys to str for JSON objects
|
|
292
|
+
return dict[str, v] if k is not str else dict[str, v]
|
|
293
|
+
|
|
294
|
+
# fallback: avoid Any
|
|
295
|
+
return JsonValue
|
|
@@ -7,6 +7,7 @@ from langchain_core.tools import BaseTool
|
|
|
7
7
|
from pydantic import create_model, BaseModel, ConfigDict, Field
|
|
8
8
|
from pydantic.fields import FieldInfo
|
|
9
9
|
from ..tools.artifact import ArtifactWrapper
|
|
10
|
+
from ..utils.constants import TOOLKIT_NAME_META, TOOLKIT_TYPE_META, TOOL_NAME_META
|
|
10
11
|
from ...tools.base.tool import BaseAction
|
|
11
12
|
from ...configurations.pgvector import PgVectorConfiguration
|
|
12
13
|
|
|
@@ -68,7 +69,7 @@ class ArtifactToolkit(BaseToolkit):
|
|
|
68
69
|
name=tool["name"],
|
|
69
70
|
description=description,
|
|
70
71
|
args_schema=tool["args_schema"],
|
|
71
|
-
metadata={
|
|
72
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: "artifact", TOOL_NAME_META: tool['name']} if toolkit_name else {}
|
|
72
73
|
))
|
|
73
74
|
return cls(tools=tools)
|
|
74
75
|
|
|
@@ -615,7 +615,8 @@ class McpToolkit(BaseToolkit):
|
|
|
615
615
|
args_schema=McpServerTool.create_pydantic_model_from_schema(tool_metadata.input_schema),
|
|
616
616
|
client=client,
|
|
617
617
|
server=tool_metadata.server,
|
|
618
|
-
tool_timeout_sec=timeout
|
|
618
|
+
tool_timeout_sec=timeout,
|
|
619
|
+
metadata={"toolkit_name": toolkit_name, "toolkit_type": name}
|
|
619
620
|
)
|
|
620
621
|
except Exception as e:
|
|
621
622
|
logger.error(f"Failed to create MCP tool '{tool_name}' from server '{tool_metadata.server}': {e}")
|
|
@@ -649,7 +650,8 @@ class McpToolkit(BaseToolkit):
|
|
|
649
650
|
),
|
|
650
651
|
client=client,
|
|
651
652
|
server=toolkit_name,
|
|
652
|
-
tool_timeout_sec=timeout
|
|
653
|
+
tool_timeout_sec=timeout,
|
|
654
|
+
metadata={"toolkit_name": toolkit_name, "toolkit_type": name}
|
|
653
655
|
)
|
|
654
656
|
except Exception as e:
|
|
655
657
|
logger.error(f"Failed to create MCP tool '{tool_name}' from toolkit '{toolkit_name}': {e}")
|
|
@@ -147,7 +147,7 @@ class SkillRouterToolkit(BaseToolkit):
|
|
|
147
147
|
name=tool["name"],
|
|
148
148
|
description=description,
|
|
149
149
|
args_schema=tool["args_schema"],
|
|
150
|
-
metadata={"toolkit_name": toolkit_name} if toolkit_name else {}
|
|
150
|
+
metadata={"toolkit_name": toolkit_name, "toolkit_type": "skill_router"} if toolkit_name else {}
|
|
151
151
|
))
|
|
152
152
|
|
|
153
153
|
return cls(tools=tools)
|
|
@@ -56,7 +56,7 @@ class VectorStoreToolkit(BaseToolkit):
|
|
|
56
56
|
name=tool["name"],
|
|
57
57
|
description=description,
|
|
58
58
|
args_schema=tool["args_schema"],
|
|
59
|
-
metadata={"toolkit_name": toolkit_name} if toolkit_name else {}
|
|
59
|
+
metadata={"toolkit_name": toolkit_name, "toolkit_type": "vectorstore"} if toolkit_name else {}
|
|
60
60
|
))
|
|
61
61
|
return cls(tools=tools)
|
|
62
62
|
|
|
@@ -174,7 +174,7 @@ class DataAnalysisToolkit(BaseToolkit):
|
|
|
174
174
|
name=tool["name"],
|
|
175
175
|
description=description,
|
|
176
176
|
args_schema=tool["args_schema"],
|
|
177
|
-
metadata={"toolkit_name": toolkit_name} if toolkit_name else {}
|
|
177
|
+
metadata={"toolkit_name": toolkit_name, "toolkit_type": name} if toolkit_name else {}
|
|
178
178
|
))
|
|
179
179
|
|
|
180
180
|
return cls(tools=tools)
|