alita-sdk 0.3.373__py3-none-any.whl → 0.3.374__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/runtime/clients/artifact.py +1 -1
- alita_sdk/runtime/clients/sandbox_client.py +365 -0
- alita_sdk/runtime/langchain/assistant.py +4 -2
- alita_sdk/runtime/langchain/langraph_agent.py +12 -0
- alita_sdk/runtime/tools/function.py +71 -0
- alita_sdk/runtime/tools/sandbox.py +16 -18
- {alita_sdk-0.3.373.dist-info → alita_sdk-0.3.374.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.373.dist-info → alita_sdk-0.3.374.dist-info}/RECORD +11 -10
- {alita_sdk-0.3.373.dist-info → alita_sdk-0.3.374.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.373.dist-info → alita_sdk-0.3.374.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.373.dist-info → alita_sdk-0.3.374.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Optional
|
|
3
|
+
from urllib.parse import quote
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
from typing import Any
|
|
7
|
+
from json import dumps
|
|
8
|
+
import chardet
|
|
9
|
+
|
|
10
|
+
# Module-level logger used by the artifact and client helpers defined below.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ApiDetailsRequestError(Exception):
    """Raised when the platform API does not return the requested details (non-OK response)."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SandboxArtifact:
|
|
18
|
+
def __init__(self, client: Any, bucket_name: str):
|
|
19
|
+
self.client = client
|
|
20
|
+
self.bucket_name = bucket_name
|
|
21
|
+
if not self.client.bucket_exists(bucket_name):
|
|
22
|
+
self.client.create_bucket(bucket_name)
|
|
23
|
+
|
|
24
|
+
def create(self, artifact_name: str, artifact_data: Any, bucket_name: str = None):
|
|
25
|
+
try:
|
|
26
|
+
if not bucket_name:
|
|
27
|
+
bucket_name = self.bucket_name
|
|
28
|
+
return dumps(self.client.create_artifact(bucket_name, artifact_name, artifact_data))
|
|
29
|
+
except Exception as e:
|
|
30
|
+
logger.error(f'Error: {e}')
|
|
31
|
+
return f'Error: {e}'
|
|
32
|
+
|
|
33
|
+
def get(self,
|
|
34
|
+
artifact_name: str,
|
|
35
|
+
bucket_name: str = None,
|
|
36
|
+
is_capture_image: bool = False,
|
|
37
|
+
page_number: int = None,
|
|
38
|
+
sheet_name: str = None,
|
|
39
|
+
excel_by_sheets: bool = False,
|
|
40
|
+
llm=None):
|
|
41
|
+
if not bucket_name:
|
|
42
|
+
bucket_name = self.bucket_name
|
|
43
|
+
data = self.client.download_artifact(bucket_name, artifact_name)
|
|
44
|
+
if len(data) == 0:
|
|
45
|
+
# empty file might be created
|
|
46
|
+
return ''
|
|
47
|
+
if isinstance(data, dict) and data['error']:
|
|
48
|
+
return f'{data['error']}. {data['content'] if data['content'] else ''}'
|
|
49
|
+
detected = chardet.detect(data)
|
|
50
|
+
return data
|
|
51
|
+
# TODO: add proper handling for binary files (images, pdf, etc.) for sandbox
|
|
52
|
+
# if detected['encoding'] is not None:
|
|
53
|
+
# try:
|
|
54
|
+
# return data.decode(detected['encoding'])
|
|
55
|
+
# except Exception:
|
|
56
|
+
# logger.error('Error while default encoding')
|
|
57
|
+
# return parse_file_content(file_name=artifact_name,
|
|
58
|
+
# file_content=data,
|
|
59
|
+
# is_capture_image=is_capture_image,
|
|
60
|
+
# page_number=page_number,
|
|
61
|
+
# sheet_name=sheet_name,
|
|
62
|
+
# excel_by_sheets=excel_by_sheets,
|
|
63
|
+
# llm=llm)
|
|
64
|
+
# else:
|
|
65
|
+
# return parse_file_content(file_name=artifact_name,
|
|
66
|
+
# file_content=data,
|
|
67
|
+
# is_capture_image=is_capture_image,
|
|
68
|
+
# page_number=page_number,
|
|
69
|
+
# sheet_name=sheet_name,
|
|
70
|
+
# excel_by_sheets=excel_by_sheets,
|
|
71
|
+
# llm=llm)
|
|
72
|
+
|
|
73
|
+
def delete(self, artifact_name: str, bucket_name=None):
|
|
74
|
+
if not bucket_name:
|
|
75
|
+
bucket_name = self.bucket_name
|
|
76
|
+
self.client.delete_artifact(bucket_name, artifact_name)
|
|
77
|
+
|
|
78
|
+
def list(self, bucket_name: str = None, return_as_string=True) -> str | dict:
|
|
79
|
+
if not bucket_name:
|
|
80
|
+
bucket_name = self.bucket_name
|
|
81
|
+
artifacts = self.client.list_artifacts(bucket_name)
|
|
82
|
+
return str(artifacts) if return_as_string else artifacts
|
|
83
|
+
|
|
84
|
+
def append(self, artifact_name: str, additional_data: Any, bucket_name: str = None):
|
|
85
|
+
if not bucket_name:
|
|
86
|
+
bucket_name = self.bucket_name
|
|
87
|
+
data = self.get(artifact_name, bucket_name)
|
|
88
|
+
if data == 'Could not detect encoding':
|
|
89
|
+
return data
|
|
90
|
+
data += f'{additional_data}' if len(data) > 0 else additional_data
|
|
91
|
+
self.client.create_artifact(bucket_name, artifact_name, data)
|
|
92
|
+
return 'Data appended successfully'
|
|
93
|
+
|
|
94
|
+
def overwrite(self, artifact_name: str, new_data: Any, bucket_name: str = None):
|
|
95
|
+
if not bucket_name:
|
|
96
|
+
bucket_name = self.bucket_name
|
|
97
|
+
return self.create(artifact_name, new_data, bucket_name)
|
|
98
|
+
|
|
99
|
+
def get_content_bytes(self,
|
|
100
|
+
artifact_name: str,
|
|
101
|
+
bucket_name: str = None):
|
|
102
|
+
if not bucket_name:
|
|
103
|
+
bucket_name = self.bucket_name
|
|
104
|
+
return self.client.download_artifact(bucket_name, artifact_name)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class SandboxClient:
    """Lightweight HTTP client for the Alita platform API.

    All requests carry ``Authorization: Bearer <auth_token>`` and ``X-SECRET``
    headers (plus any ``api_extra_headers``) and are sent with ``verify=False``.

    Recognised ``kwargs``: ``XSECRET`` (value of the X-SECRET header, default
    ``'secret'``), ``model_timeout`` (seconds, default 120, used for image
    generation), ``model_image_generation`` (model name for ``generate_image``)
    and ``user_id`` (user on whose behalf MCP calls are made).
    """

    def __init__(self,
                 base_url: str,
                 project_id: int,
                 auth_token: str,
                 api_extra_headers: Optional[dict] = None,
                 configurations: Optional[list] = None,
                 **kwargs):

        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1'
        self.llm_path = '/llm/v1'
        self.project_id = project_id
        self.auth_token = auth_token
        self.headers = {
            'Authorization': f'Bearer {auth_token}',
            'X-SECRET': kwargs.get('XSECRET', 'secret')
        }
        if api_extra_headers is not None:
            self.headers.update(api_extra_headers)
        self.predict_url = f'{self.base_url}{self.api_path}/prompt_lib/predict/prompt_lib/{self.project_id}'
        self.prompt_versions = f'{self.base_url}{self.api_path}/prompt_lib/version/prompt_lib/{self.project_id}'
        self.prompts = f'{self.base_url}{self.api_path}/prompt_lib/prompt/prompt_lib/{self.project_id}'
        self.datasources = f'{self.base_url}{self.api_path}/datasources/datasource/prompt_lib/{self.project_id}'
        self.datasources_predict = f'{self.base_url}{self.api_path}/datasources/predict/prompt_lib/{self.project_id}'
        self.datasources_search = f'{self.base_url}{self.api_path}/datasources/search/prompt_lib/{self.project_id}'
        self.app = f'{self.base_url}{self.api_path}/applications/application/prompt_lib/{self.project_id}'
        self.mcp_tools_list = f'{self.base_url}{self.api_path}/mcp_sse/tools_list/{self.project_id}'
        self.mcp_tools_call = f'{self.base_url}{self.api_path}/mcp_sse/tools_call/{self.project_id}'
        self.application_versions = f'{self.base_url}{self.api_path}/applications/version/prompt_lib/{self.project_id}'
        self.list_apps_url = f'{self.base_url}{self.api_path}/applications/applications/prompt_lib/{self.project_id}'
        self.integration_details = f'{self.base_url}{self.api_path}/integrations/integration/{self.project_id}'
        self.secrets_url = f'{self.base_url}{self.api_path}/secrets/secret/{self.project_id}'
        self.artifacts_url = f'{self.base_url}{self.api_path}/artifacts/artifacts/default/{self.project_id}'
        self.artifact_url = f'{self.base_url}{self.api_path}/artifacts/artifact/default/{self.project_id}'
        self.bucket_url = f'{self.base_url}{self.api_path}/artifacts/buckets/{self.project_id}'
        self.configurations_url = f'{self.base_url}{self.api_path}/integrations/integrations/default/{self.project_id}?section=configurations&unsecret=true'
        self.ai_section_url = f'{self.base_url}{self.api_path}/integrations/integrations/default/{self.project_id}?section=ai'
        self.image_generation_url = f'{self.base_url}{self.llm_path}/images/generations'
        self.configurations: list = configurations or []
        self.model_timeout = kwargs.get('model_timeout', 120)
        self.model_image_generation = kwargs.get('model_image_generation')
        # The sandbox has no way to discover the calling user, so MCP calls are only
        # possible when the user id is supplied explicitly.
        self.user_id = kwargs.get('user_id')

    def _get_real_user_id(self):
        """Return the user id supplied via the ``user_id`` keyword, or ``None``."""
        return getattr(self, 'user_id', None)

    @staticmethod
    def _to_plain(value: Any) -> Any:
        """Return ``value.dict()`` when ``value`` has a callable ``dict`` (e.g. Pydantic models), else ``value``."""
        to_dict = getattr(value, 'dict', None)
        return to_dict() if callable(to_dict) else value

    def get_mcp_toolkits(self):
        """Return the MCP toolkit list for the configured user, or ``[]`` when no user id is known."""
        if user_id := self._get_real_user_id():
            url = f'{self.mcp_tools_list}/{user_id}'
            data = requests.get(url, headers=self.headers, verify=False).json()
            return data
        else:
            return []

    def mcp_tool_call(self, params: dict[str, Any]):
        """POST an MCP tool call and return the JSON response (or raw text if it is not JSON).

        Pydantic-like values in ``params['params']['arguments']`` (and inside list
        arguments) are converted with ``.dict()`` first; the caller's ``params``
        is not modified. Returns an error string when no user id is known.
        """
        user_id = self._get_real_user_id()
        if not user_id:
            return 'Error: Could not determine user ID for MCP tool call'
        url = f'{self.mcp_tools_call}/{user_id}'

        payload = params
        inner = params.get('params') or {}
        arguments = inner.get('arguments') or {}
        if arguments:
            converted = {
                arg_name: ([self._to_plain(item) for item in arg_value]
                           if isinstance(arg_value, list) else self._to_plain(arg_value))
                for arg_name, arg_value in arguments.items()
            }
            # Build a new payload instead of mutating the caller's dict.
            payload = {**params, 'params': {**inner, 'arguments': converted}}

        response = requests.post(url, headers=self.headers, json=payload, verify=False)
        try:
            return response.json()
        except (ValueError, TypeError):
            return response.text

    def get_app_details(self, application_id: int):
        """Return the decoded JSON description of application ``application_id``."""
        url = f'{self.app}/{application_id}'
        data = requests.get(url, headers=self.headers, verify=False).json()
        return data

    def get_list_of_apps(self):
        """Return ``[{'name': ..., 'id': ...}]`` for every application, fetched 10 per page.

        Stops on the first non-OK response, when the server omits ``total``, or
        when a page is empty, so a malformed response cannot loop forever.
        """
        apps = []
        limit = 10
        offset = 0
        total_count = None

        while total_count is None or offset < total_count:
            params = {'offset': offset, 'limit': limit}
            resp = requests.get(self.list_apps_url, headers=self.headers, params=params, verify=False)
            if not resp.ok:
                break
            data = resp.json()
            rows = data.get('rows', [])
            apps.extend({'name': app['name'], 'id': app['id']} for app in rows)
            total_count = data.get('total')
            offset += limit
            if total_count is None or not rows:
                break

        return apps

    def fetch_available_configurations(self) -> list:
        """Return the project's configurations (unsecreted) or ``[]`` on a non-OK response."""
        resp = requests.get(self.configurations_url, headers=self.headers, verify=False)
        if resp.ok:
            return resp.json()
        return []

    def all_models_and_integrations(self):
        """Return the project's AI integrations section or ``[]`` on a non-OK response."""
        resp = requests.get(self.ai_section_url, headers=self.headers, verify=False)
        if resp.ok:
            return resp.json()
        return []

    def generate_image(self,
                       prompt: str,
                       n: int = 1,
                       size: str = 'auto',
                       quality: str = 'auto',
                       response_format: str = 'b64_json',
                       style: Optional[str] = None) -> dict:
        """Request ``n`` images for ``prompt`` from the configured image model.

        ``size``/``quality`` are sent only when not ``'auto'``; ``style`` only when set.

        Raises:
            ValueError: no ``model_image_generation`` was configured.
            requests.exceptions.RequestException: the request failed (re-raised after logging).
        """

        if not self.model_image_generation:
            raise ValueError('Image generation model is not configured for this client')

        image_generation_data = {
            'prompt': prompt,
            'model': self.model_image_generation,
            'n': n,
            'response_format': response_format,
        }

        # Only add optional parameters if they have meaningful values
        if size and size.lower() != 'auto':
            image_generation_data['size'] = size

        if quality and quality.lower() != 'auto':
            image_generation_data['quality'] = quality

        if style:
            image_generation_data['style'] = style

        # Standard headers for image generation
        image_headers = self.headers.copy()
        image_headers.update({
            'Content-Type': 'application/json',
        })

        logger.info(f'Generating image with model: {self.model_image_generation}, prompt: {prompt[:50]}...')

        try:
            response = requests.post(
                self.image_generation_url,
                headers=image_headers,
                json=image_generation_data,
                verify=False,
                timeout=self.model_timeout
            )
            response.raise_for_status()
            return response.json()

        except requests.exceptions.HTTPError as e:
            logger.error(f'Image generation failed: {e.response.status_code} - {e.response.text}')
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f'Image generation request failed: {e}')
            raise

    def get_app_version_details(self, application_id: int, application_version_id: int) -> dict:
        """Return version details, sending the known (or freshly fetched) configurations.

        Raises:
            ApiDetailsRequestError: the platform returned a non-OK response.
        """
        url = f'{self.application_versions}/{application_id}/{application_version_id}'
        if self.configurations:
            configs = self.configurations
        else:
            configs = self.fetch_available_configurations()

        resp = requests.patch(url, headers=self.headers, verify=False, json={'configurations': configs})
        if resp.ok:
            return resp.json()
        logger.error(f'Failed to fetch application version details: {resp.status_code} - {resp.text}.'
                     f' Application ID: {application_id}, Version ID: {application_version_id}')
        raise ApiDetailsRequestError(
            f'Failed to fetch application version details for {application_id}/{application_version_id}.')

    def get_integration_details(self, integration_id: str, format_for_model: bool = False):
        """Return the decoded JSON for ``integration_id`` (``format_for_model`` is currently unused)."""
        url = f'{self.integration_details}/{integration_id}'
        data = requests.get(url, headers=self.headers, verify=False).json()
        return data

    def unsecret(self, secret_name: str):
        """Return the value of secret ``secret_name`` (``None`` if the response has no ``value``)."""
        url = f'{self.secrets_url}/{secret_name}'
        data = requests.get(url, headers=self.headers, verify=False).json()
        # Never log the response body: it contains the secret value itself.
        logger.info(f'Unsecret response received for secret: {secret_name}')
        return data.get('value', None)

    def artifact(self, bucket_name):
        """Return a ``SandboxArtifact`` bound to ``bucket_name`` (created if missing)."""
        return SandboxArtifact(self, bucket_name)

    def _process_requst(self, data: 'requests.Response') -> Any:
        """Map a response to its decoded JSON, or to an ``{'error': ...}`` dict on non-200 status."""
        if data.status_code == 403:
            return {'error': 'You are not authorized to access this resource'}
        elif data.status_code == 404:
            return {'error': 'Resource not found'}
        elif data.status_code != 200:
            return {
                'error': 'An error occurred while fetching the resource',
                'content': data.text
            }
        else:
            return data.json()

    def bucket_exists(self, bucket_name):
        """Return ``True`` if a bucket named ``bucket_name`` is listed; ``False`` otherwise or on failure."""
        try:
            resp = self._process_requst(
                requests.get(f'{self.bucket_url}', headers=self.headers, verify=False)
            )
            return any(each['name'] == bucket_name for each in resp.get('rows', []))
        except Exception as e:
            logger.warning(f'Could not check whether bucket {bucket_name} exists: {e}')
            return False

    def create_bucket(self, bucket_name, expiration_measure='months', expiration_value=1):
        """Create ``bucket_name`` with the given expiration and return the processed response."""
        post_data = {
            'name': bucket_name,
            'expiration_measure': expiration_measure,
            'expiration_value': expiration_value
        }
        resp = requests.post(f'{self.bucket_url}', headers=self.headers, json=post_data, verify=False)
        return self._process_requst(resp)

    def list_artifacts(self, bucket_name: str):
        """Return the processed artifact listing of ``bucket_name``."""
        # Ensure bucket name is lowercase as required by the API
        url = f'{self.artifacts_url}/{bucket_name.lower()}'
        data = requests.get(url, headers=self.headers, verify=False)
        return self._process_requst(data)

    def create_artifact(self, bucket_name, artifact_name, artifact_data):
        """Upload ``artifact_data`` as multipart file ``artifact_name`` and return the processed response."""
        url = f'{self.artifacts_url}/{bucket_name.lower()}'
        data = requests.post(url, headers=self.headers, files={
            'file': (artifact_name, artifact_data)
        }, verify=False)
        return self._process_requst(data)

    def download_artifact(self, bucket_name, artifact_name):
        """Return the artifact's raw bytes, or an ``{'error': ...}`` dict on non-200 status."""
        url = f'{self.artifact_url}/{bucket_name.lower()}/{artifact_name}'
        data = requests.get(url, headers=self.headers, verify=False)
        if data.status_code == 403:
            return {'error': 'You are not authorized to access this resource'}
        elif data.status_code == 404:
            return {'error': 'Resource not found'}
        elif data.status_code != 200:
            return {
                'error': 'An error occurred while fetching the resource',
                'content': data.content
            }
        return data.content

    def delete_artifact(self, bucket_name, artifact_name):
        """Delete ``artifact_name`` from ``bucket_name`` and return the processed response."""
        # Lowercase like every other artifact endpoint; buckets are addressed in lowercase.
        url = f'{self.artifact_url}/{bucket_name.lower()}'
        data = requests.delete(url, headers=self.headers, verify=False, params={'filename': quote(artifact_name)})
        return self._process_requst(data)
|
|
@@ -314,7 +314,8 @@ class Assistant:
|
|
|
314
314
|
memory=checkpointer,
|
|
315
315
|
store=self.store,
|
|
316
316
|
debug=False,
|
|
317
|
-
for_subgraph=False
|
|
317
|
+
for_subgraph=False,
|
|
318
|
+
alita_client=self.alita_client
|
|
318
319
|
)
|
|
319
320
|
|
|
320
321
|
return agent
|
|
@@ -328,7 +329,8 @@ class Assistant:
|
|
|
328
329
|
#
|
|
329
330
|
agent = create_graph(
|
|
330
331
|
client=self.client, tools=self.tools,
|
|
331
|
-
yaml_schema=self.prompt, memory=memory
|
|
332
|
+
yaml_schema=self.prompt, memory=memory,
|
|
333
|
+
alita_client=self.alita_client
|
|
332
334
|
)
|
|
333
335
|
#
|
|
334
336
|
return agent
|
|
@@ -553,6 +553,18 @@ def create_graph(
|
|
|
553
553
|
input_variables=node.get('input', ['messages']),
|
|
554
554
|
structured_output=node.get('structured_output', False)))
|
|
555
555
|
break
|
|
556
|
+
elif node_type == 'code':
|
|
557
|
+
from ..tools.sandbox import create_sandbox_tool
|
|
558
|
+
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True)
|
|
559
|
+
code = node.get('code', "return 'Code block is empty'")
|
|
560
|
+
lg_builder.add_node(node_id, FunctionTool(
|
|
561
|
+
tool=sandbox_tool, name=node['id'], return_type='dict',
|
|
562
|
+
output_variables=node.get('output', []),
|
|
563
|
+
input_mapping={'code': {'type': 'fixed', 'value': code}},
|
|
564
|
+
input_variables=node.get('input', ['messages']),
|
|
565
|
+
structured_output=node.get('structured_output', False),
|
|
566
|
+
alita_client=kwargs.get('alita_client', None)
|
|
567
|
+
))
|
|
556
568
|
elif node_type == 'llm':
|
|
557
569
|
output_vars = node.get('output', [])
|
|
558
570
|
output_vars_dict = {
|
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import logging
|
|
3
|
+
from copy import deepcopy
|
|
2
4
|
from json import dumps
|
|
3
5
|
|
|
4
6
|
from langchain_core.callbacks import dispatch_custom_event
|
|
@@ -8,6 +10,7 @@ from langchain_core.tools import BaseTool, ToolException
|
|
|
8
10
|
from typing import Any, Optional, Union, Annotated
|
|
9
11
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
|
10
12
|
from pydantic import ValidationError
|
|
13
|
+
|
|
11
14
|
from ..langchain.utils import propagate_the_input_mapping
|
|
12
15
|
|
|
13
16
|
logger = logging.getLogger(__name__)
|
|
@@ -21,6 +24,63 @@ class FunctionTool(BaseTool):
|
|
|
21
24
|
input_variables: Optional[list[str]] = None
|
|
22
25
|
input_mapping: Optional[dict[str, dict]] = None
|
|
23
26
|
output_variables: Optional[list[str]] = None
|
|
27
|
+
structured_output: Optional[bool] = False
|
|
28
|
+
alita_client: Optional[Any] = None
|
|
29
|
+
|
|
30
|
+
def _prepare_pyodide_input(self, state: Union[str, dict, ToolCall]) -> str:
    """Build the Python prelude prepended to a PyodideSandboxTool code block.

    The prelude:
      * restores ``state`` (without its ``messages`` entry) as ``alita_state`` via pickle;
      * inlines the source of ``alita_sdk/runtime/clients/sandbox_client.py``;
      * when ``self.alita_client`` is set, instantiates ``alita_client = SandboxClient(...)``
        with the host client's base_url, project_id and auth_token.

    Returns the prelude as a string (only the state part if the client source is missing).
    """
    import pickle
    from pathlib import Path

    # pickle.dumps takes a snapshot, so later changes to the live state cannot leak
    # into the sandbox; no deepcopy is needed. 'messages' is dropped because the
    # sandbox has no langchain-core to unpickle message objects with.
    if isinstance(state, dict):
        state_snapshot = {key: value for key, value in state.items() if key != 'messages'}
    else:
        state_snapshot = state
    serialized_state = pickle.dumps(state_snapshot)
    # repr() of bytes is a valid bytes literal in the generated code.
    pyodide_predata = f"import pickle\nalita_state = pickle.loads({serialized_state!r})\n"

    # Resolve the client source relative to this package, not the process CWD
    # (this module lives in alita_sdk/runtime/tools/).
    client_path = Path(__file__).resolve().parent.parent / 'clients' / 'sandbox_client.py'
    try:
        sandbox_client_code = client_path.read_text(encoding='utf-8')
    except FileNotFoundError:
        logger.error(f"sandbox_client.py not found. Ensure '{client_path}' exists.")
        return pyodide_predata
    pyodide_predata += f"\n{sandbox_client_code}\n"

    if self.alita_client is not None:
        # NOTE(review): this embeds the auth token in the generated code; make sure the
        # composed code is never logged or persisted.
        # repr() quotes/escapes the values so they cannot break the generated code.
        pyodide_predata += (f"alita_client = SandboxClient(base_url={self.alita_client.base_url!r},"
                            f"project_id={self.alita_client.project_id!r},"
                            f"auth_token={self.alita_client.auth_token!r})")
    return pyodide_predata
|
|
53
|
+
|
|
54
|
+
def _handle_pyodide_output(self, tool_result: Any) -> dict:
    """Convert a PyodideSandboxTool result into a graph-state update dict.

    ``tool_result`` is normally the dict built by the sandbox tool (keys such as
    'result', 'output', 'error', 'status', 'execution_info'); any other value is
    treated as the bare 'result'.

    * With ``output_variables``, each variable is taken from the result dict when
      present; otherwise it falls back to 'result', then 'error', then
      ``'Execution result is missing'``.
    * Without them, the whole result is returned as one assistant message.
    * With ``structured_output``, a dict 'result' (or a JSON-object string) is merged
      into the update; unparseable or non-object values are logged and ignored.
    """
    result_dict = tool_result if isinstance(tool_result, dict) else {'result': tool_result}
    tool_result_converted = {}

    if self.output_variables:
        # handler in case user points to a var that is not in the output of the tool
        fallback = result_dict.get('result', result_dict.get('error') or 'Execution result is missing')
        for var in self.output_variables:
            tool_result_converted[var] = result_dict[var] if var in result_dict else fallback
    else:
        tool_result_converted.update({"messages": [{"role": "assistant", "content": dumps(tool_result)}]})

    if self.structured_output:
        # execute code tool and update state variables
        result_value = result_dict.get('result', {})
        try:
            parsed = result_value if isinstance(result_value, dict) else json.loads(result_value)
        except (json.JSONDecodeError, TypeError) as e:
            logger.error(f"Could not parse structured output ({e}): {tool_result}")
        else:
            if isinstance(parsed, dict):
                tool_result_converted.update(parsed)
            else:
                logger.error(f"Structured output is not a JSON object: {tool_result}")

    return tool_result_converted
|
|
80
|
+
|
|
81
|
+
def _is_pyodide_tool(self) -> bool:
    """Tell whether the wrapped tool is the Pyodide sandbox (name match, case-insensitive)."""
    wrapped_name = self.tool.name
    return 'pyodide_sandbox' == wrapped_name.lower()
|
|
24
84
|
|
|
25
85
|
def invoke(
|
|
26
86
|
self,
|
|
@@ -31,8 +91,14 @@ class FunctionTool(BaseTool):
|
|
|
31
91
|
params = convert_to_openai_tool(self.tool).get(
|
|
32
92
|
'function', {'parameters': {}}).get(
|
|
33
93
|
'parameters', {'properties': {}}).get('properties', {})
|
|
94
|
+
|
|
34
95
|
func_args = propagate_the_input_mapping(input_mapping=self.input_mapping, input_variables=self.input_variables,
|
|
35
96
|
state=state)
|
|
97
|
+
|
|
98
|
+
# special handler for PyodideSandboxTool
|
|
99
|
+
if self._is_pyodide_tool():
|
|
100
|
+
code = func_args['code']
|
|
101
|
+
func_args['code'] = f"{self._prepare_pyodide_input(state)}\n{code}"
|
|
36
102
|
try:
|
|
37
103
|
tool_result = self.tool.invoke(func_args, config, **kwargs)
|
|
38
104
|
dispatch_custom_event(
|
|
@@ -44,6 +110,11 @@ class FunctionTool(BaseTool):
|
|
|
44
110
|
}, config=config
|
|
45
111
|
)
|
|
46
112
|
logger.info(f"ToolNode response: {tool_result}")
|
|
113
|
+
|
|
114
|
+
# handler for PyodideSandboxTool
|
|
115
|
+
if self._is_pyodide_tool():
|
|
116
|
+
return self._handle_pyodide_output(tool_result)
|
|
117
|
+
|
|
47
118
|
if not self.output_variables:
|
|
48
119
|
return {"messages": [{"role": "assistant", "content": dumps(tool_result)}]}
|
|
49
120
|
else:
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import logging
|
|
2
1
|
import asyncio
|
|
2
|
+
import logging
|
|
3
3
|
import subprocess
|
|
4
4
|
import os
|
|
5
|
-
from typing import Any, Type, Optional,
|
|
5
|
+
from typing import Any, Type, Optional, Dict
|
|
6
6
|
from langchain_core.tools import BaseTool
|
|
7
|
-
from pydantic import BaseModel,
|
|
7
|
+
from pydantic import BaseModel, create_model
|
|
8
8
|
from pydantic.fields import FieldInfo
|
|
9
9
|
|
|
10
10
|
logger = logging.getLogger(__name__)
|
|
@@ -190,30 +190,28 @@ class PyodideSandboxTool(BaseTool):
|
|
|
190
190
|
self.session_bytes = result.session_bytes
|
|
191
191
|
self.session_metadata = result.session_metadata
|
|
192
192
|
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
193
|
+
result_dict = {}
|
|
194
|
+
|
|
196
195
|
if result.result is not None:
|
|
197
|
-
|
|
198
|
-
|
|
196
|
+
result_dict["result"] = result.result
|
|
197
|
+
|
|
199
198
|
if result.stdout:
|
|
200
|
-
|
|
201
|
-
|
|
199
|
+
result_dict["output"] = result.stdout
|
|
200
|
+
|
|
202
201
|
if result.stderr:
|
|
203
|
-
|
|
204
|
-
|
|
202
|
+
result_dict["error"] = result.stderr
|
|
203
|
+
|
|
205
204
|
if result.status == 'error':
|
|
206
|
-
|
|
207
|
-
|
|
205
|
+
result_dict["status"] = "Execution failed"
|
|
206
|
+
|
|
208
207
|
execution_info = f"Execution time: {result.execution_time:.2f}s"
|
|
209
208
|
if result.session_metadata and 'packages' in result.session_metadata:
|
|
210
209
|
packages = result.session_metadata.get('packages', [])
|
|
211
210
|
if packages:
|
|
212
211
|
execution_info += f", Packages: {', '.join(packages)}"
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
return "\n".join(output_parts) if output_parts else "Code executed successfully (no output)"
|
|
212
|
+
|
|
213
|
+
result_dict["execution_info"] = execution_info
|
|
214
|
+
return result_dict
|
|
217
215
|
|
|
218
216
|
except Exception as e:
|
|
219
217
|
logger.error(f"Error executing code in sandbox: {e}")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alita_sdk
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.374
|
|
4
4
|
Summary: SDK for building langchain agents using resources from Alita
|
|
5
5
|
Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedj27@gmail.com>, Artem Dubrovskiy <ad13box@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -35,16 +35,17 @@ alita_sdk/configurations/zephyr_enterprise.py,sha256=UaBk3qWcT2-bCzko5HEPvgxArw1
|
|
|
35
35
|
alita_sdk/configurations/zephyr_essential.py,sha256=tUIrh-PRNvdrLBj6rJXqlF-h6oaMXUQI1wgit07kFBw,752
|
|
36
36
|
alita_sdk/runtime/__init__.py,sha256=4W0UF-nl3QF2bvET5lnah4o24CoTwSoKXhuN0YnwvEE,828
|
|
37
37
|
alita_sdk/runtime/clients/__init__.py,sha256=BdehU5GBztN1Qi1Wul0cqlU46FxUfMnI6Vq2Zd_oq1M,296
|
|
38
|
-
alita_sdk/runtime/clients/artifact.py,sha256=
|
|
38
|
+
alita_sdk/runtime/clients/artifact.py,sha256=b7hVuGRROt6qUcT11uAZqzJqslzmlgW-Y6oGsiwNmjI,4029
|
|
39
39
|
alita_sdk/runtime/clients/client.py,sha256=BIF6QSnhlTfsTQ_dQs-QZjeBJHZsOtSuv_q7_ABUUQg,45737
|
|
40
40
|
alita_sdk/runtime/clients/datasource.py,sha256=HAZovoQN9jBg0_-lIlGBQzb4FJdczPhkHehAiVG3Wx0,1020
|
|
41
41
|
alita_sdk/runtime/clients/prompt.py,sha256=li1RG9eBwgNK_Qf0qUaZ8QNTmsncFrAL2pv3kbxZRZg,1447
|
|
42
|
+
alita_sdk/runtime/clients/sandbox_client.py,sha256=OhEasE0MxBBDw4o76xkxVCpNpr3xJ8spQsrsVxMrjUA,16192
|
|
42
43
|
alita_sdk/runtime/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
43
|
-
alita_sdk/runtime/langchain/assistant.py,sha256=
|
|
44
|
+
alita_sdk/runtime/langchain/assistant.py,sha256=YsxYNoaEidV02VlPwccdHP7PKeRRPp9M3tvUiYIDQ-I,15514
|
|
44
45
|
alita_sdk/runtime/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
|
|
45
46
|
alita_sdk/runtime/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
|
|
46
47
|
alita_sdk/runtime/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
|
|
47
|
-
alita_sdk/runtime/langchain/langraph_agent.py,sha256=
|
|
48
|
+
alita_sdk/runtime/langchain/langraph_agent.py,sha256=R4h_m_7NUgays7lt-F9WvKEOnGr1Yz7OgrmLMiGxurQ,48530
|
|
48
49
|
alita_sdk/runtime/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
|
|
49
50
|
alita_sdk/runtime/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
|
|
50
51
|
alita_sdk/runtime/langchain/store_manager.py,sha256=i8Fl11IXJhrBXq1F1ukEVln57B1IBe-tqSUvfUmBV4A,2218
|
|
@@ -109,7 +110,7 @@ alita_sdk/runtime/tools/application.py,sha256=z3vLZODs-_xEEnZFmGF0fKz1j3VtNJxqsA
|
|
|
109
110
|
alita_sdk/runtime/tools/artifact.py,sha256=u3szFwZqguHrPZ3tZJ7S_TiZl7cxlT3oHYd6zbdpRDE,13842
|
|
110
111
|
alita_sdk/runtime/tools/datasource.py,sha256=pvbaSfI-ThQQnjHG-QhYNSTYRnZB0rYtZFpjCfpzxYI,2443
|
|
111
112
|
alita_sdk/runtime/tools/echo.py,sha256=spw9eCweXzixJqHnZofHE1yWiSUa04L4VKycf3KCEaM,486
|
|
112
|
-
alita_sdk/runtime/tools/function.py,sha256=
|
|
113
|
+
alita_sdk/runtime/tools/function.py,sha256=0JL9D9NP31uzZ1G5br4Dhfop77l1wiqjx-7L8PHK4PA,6349
|
|
113
114
|
alita_sdk/runtime/tools/graph.py,sha256=MbnZYqdmvZY7SGDp43lOVVIjUt5ARHSgj43mdtBjSjQ,3092
|
|
114
115
|
alita_sdk/runtime/tools/image_generation.py,sha256=8ZH4SoRrbS4EzmtF6cpNMRvuFephCYD2S8uqNC9KGE4,4274
|
|
115
116
|
alita_sdk/runtime/tools/indexer_tool.py,sha256=whSLPevB4WD6dhh2JDXEivDmTvbjiMV1MrPl9cz5eLA,4375
|
|
@@ -120,7 +121,7 @@ alita_sdk/runtime/tools/mcp_server_tool.py,sha256=MhLxZJ44LYrB_0GrojmkyqKoDRaqIH
|
|
|
120
121
|
alita_sdk/runtime/tools/pgvector_search.py,sha256=NN2BGAnq4SsDHIhUcFZ8d_dbEOM8QwB0UwpsWCYruXU,11692
|
|
121
122
|
alita_sdk/runtime/tools/prompt.py,sha256=nJafb_e5aOM1Rr3qGFCR-SKziU9uCsiP2okIMs9PppM,741
|
|
122
123
|
alita_sdk/runtime/tools/router.py,sha256=p7e0tX6YAWw2M2Nq0A_xqw1E2P-Xz1DaJvhUstfoZn4,1584
|
|
123
|
-
alita_sdk/runtime/tools/sandbox.py,sha256=
|
|
124
|
+
alita_sdk/runtime/tools/sandbox.py,sha256=0OjCNsDVO1N0cFNEFVr6GVICSaqGWesUzF6LcYg-Hn0,11349
|
|
124
125
|
alita_sdk/runtime/tools/tool.py,sha256=lE1hGi6qOAXG7qxtqxarD_XMQqTghdywf261DZawwno,5631
|
|
125
126
|
alita_sdk/runtime/tools/vectorstore.py,sha256=8vRhi1lGFEs3unvnflEi2p59U2MfV32lStpEizpDms0,34467
|
|
126
127
|
alita_sdk/runtime/tools/vectorstore_base.py,sha256=wixvgLrC2tQOeIjFMCD-7869K7YfERzk2Tzmo-fgsTE,28350
|
|
@@ -352,8 +353,8 @@ alita_sdk/tools/zephyr_scale/api_wrapper.py,sha256=kT0TbmMvuKhDUZc0i7KO18O38JM9S
|
|
|
352
353
|
alita_sdk/tools/zephyr_squad/__init__.py,sha256=0ne8XLJEQSLOWfzd2HdnqOYmQlUliKHbBED5kW_Vias,2895
|
|
353
354
|
alita_sdk/tools/zephyr_squad/api_wrapper.py,sha256=kmw_xol8YIYFplBLWTqP_VKPRhL_1ItDD0_vXTe_UuI,14906
|
|
354
355
|
alita_sdk/tools/zephyr_squad/zephyr_squad_cloud_client.py,sha256=R371waHsms4sllHCbijKYs90C-9Yu0sSR3N4SUfQOgU,5066
|
|
355
|
-
alita_sdk-0.3.
|
|
356
|
-
alita_sdk-0.3.
|
|
357
|
-
alita_sdk-0.3.
|
|
358
|
-
alita_sdk-0.3.
|
|
359
|
-
alita_sdk-0.3.
|
|
356
|
+
alita_sdk-0.3.374.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
357
|
+
alita_sdk-0.3.374.dist-info/METADATA,sha256=b-L7XNDZ_LNpW-hoB_pDqOchYCdw9fOUStiXnQfSxUM,19071
|
|
358
|
+
alita_sdk-0.3.374.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
359
|
+
alita_sdk-0.3.374.dist-info/top_level.txt,sha256=0vJYy5p_jK6AwVb1aqXr7Kgqgk3WDtQ6t5C-XI9zkmg,10
|
|
360
|
+
alita_sdk-0.3.374.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|