alita-sdk 0.3.126__py3-none-any.whl → 0.3.128__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/langchain/langraph_agent.py +63 -6
- alita_sdk/toolkits/tools.py +2 -2
- alita_sdk/tools/agent.py +74 -0
- alita_sdk/tools/tool.py +8 -2
- {alita_sdk-0.3.126.dist-info → alita_sdk-0.3.128.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.126.dist-info → alita_sdk-0.3.128.dist-info}/RECORD +9 -8
- {alita_sdk-0.3.126.dist-info → alita_sdk-0.3.128.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.126.dist-info → alita_sdk-0.3.128.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.126.dist-info → alita_sdk-0.3.128.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ from uuid import uuid4
|
|
4
4
|
from typing import Dict
|
5
5
|
|
6
6
|
import yaml
|
7
|
+
import ast
|
7
8
|
from langchain_core.callbacks import dispatch_custom_event
|
8
9
|
from langchain_core.messages import HumanMessage
|
9
10
|
from langchain_core.runnables import Runnable
|
@@ -211,6 +212,25 @@ class TransitionalEdge(Runnable):
|
|
211
212
|
)
|
212
213
|
return self.next_step if self.next_step != 'END' else END
|
213
214
|
|
215
|
+
class StateDefaultNode(Runnable):
    """Graph node that emits the default values declared for state variables.

    Each entry of ``default_vars`` is expected to be a mapping; only entries
    that are dicts containing a ``'value'`` key contribute to the output.
    """

    name = "StateDefaultNode"

    def __init__(self, default_vars: Optional[dict] = None):
        """Store the state-variable definitions to read defaults from.

        Args:
            default_vars: Mapping of state variable name to its definition.
                ``None`` (the default) is treated as an empty mapping; a fresh
                dict is created per instance so instances never share state.
        """
        self.default_vars = default_vars if default_vars is not None else {}

    def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
        """Build the state update holding every declared default value.

        Args:
            state: Current graph state (not read by this node).
            config: Optional runnable configuration (not read by this node).

        Returns:
            A dict mapping each variable that declares a ``'value'`` to that
            value parsed as a Python literal, or to the raw value when it is
            not a valid literal.
        """
        logger.info("Setting default state variables")
        result = {}
        for key, value in self.default_vars.items():
            if isinstance(value, dict) and 'value' in value:
                temp_value = value['value']
                try:
                    result[key] = ast.literal_eval(temp_value)
                # These are the exceptions ast.literal_eval is documented to
                # raise on malformed input; anything else (for example
                # KeyboardInterrupt) must propagate instead of being swallowed.
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    logger.debug("Unable to evaluate value, using as is")
                    result[key] = temp_value
        return result
|
233
|
+
|
214
234
|
|
215
235
|
class StateModifierNode(Runnable):
|
216
236
|
name = "StateModifierNode"
|
@@ -353,7 +373,8 @@ def create_graph(
|
|
353
373
|
logger.debug(f"Schema: {schema}")
|
354
374
|
logger.debug(f"Tools: {tools}")
|
355
375
|
logger.info(f"Tools: {[tool.name for tool in tools]}")
|
356
|
-
|
376
|
+
state = schema.get('state', {})
|
377
|
+
state_class = create_state(state)
|
357
378
|
lg_builder = StateGraph(state_class)
|
358
379
|
interrupt_before = [clean_string(every) for every in schema.get('interrupt_before', [])]
|
359
380
|
interrupt_after = [clean_string(every) for every in schema.get('interrupt_after', [])]
|
@@ -366,7 +387,7 @@ def create_graph(
|
|
366
387
|
if toolkit_name:
|
367
388
|
tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
|
368
389
|
logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
|
369
|
-
if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph']:
|
390
|
+
if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
370
391
|
for tool in tools:
|
371
392
|
if tool.name == tool_name:
|
372
393
|
if node_type == 'function':
|
@@ -376,14 +397,30 @@ def create_graph(
|
|
376
397
|
input_mapping=node.get('input_mapping',
|
377
398
|
{'messages': {'type': 'variable', 'value': 'messages'}}),
|
378
399
|
input_variables=node.get('input', ['messages'])))
|
379
|
-
elif node_type == '
|
400
|
+
elif node_type == 'agent':
|
401
|
+
input_params = node.get('input', ['messages'])
|
402
|
+
input_mapping = {'task': {'type': 'fstring', 'value': f"{node.get('task', '')}"}}
|
403
|
+
# Add 'chat_history' to input_mapping only if 'messages' is in input_params
|
404
|
+
if 'messages' in input_params:
|
405
|
+
input_mapping['chat_history'] = {'type': 'variable', 'value': 'messages'}
|
406
|
+
lg_builder.add_node(node_id, FunctionTool(
|
407
|
+
client=client, tool=tool,
|
408
|
+
name=node['id'], return_type='dict',
|
409
|
+
output_variables=node.get('output', []),
|
410
|
+
input_variables=input_params,
|
411
|
+
input_mapping= input_mapping
|
412
|
+
))
|
413
|
+
elif node_type == 'subgraph' or node_type == 'pipeline':
|
380
414
|
# assign parent memory/store
|
381
415
|
# tool.checkpointer = memory
|
382
416
|
# tool.store = store
|
383
417
|
# wrap with mappings
|
418
|
+
pipeline_name = node.get('tool', None)
|
419
|
+
if not pipeline_name:
|
420
|
+
raise ValueError("Subgraph must have a 'tool' node: add required tool to the subgraph node")
|
384
421
|
node_fn = SubgraphRunnable(
|
385
422
|
inner=tool,
|
386
|
-
name=
|
423
|
+
name=pipeline_name,
|
387
424
|
input_mapping=node.get('input_mapping', {}),
|
388
425
|
output_mapping=node.get('output_mapping', {}),
|
389
426
|
)
|
@@ -398,6 +435,15 @@ def create_graph(
|
|
398
435
|
structured_output=node.get('structured_output', False),
|
399
436
|
task=node.get('task')
|
400
437
|
))
|
438
|
+
# TODO: decide on struct output for agent nodes
|
439
|
+
# elif node_type == 'agent':
|
440
|
+
# lg_builder.add_node(node_id, AgentNode(
|
441
|
+
# client=client, tool=tool,
|
442
|
+
# name=node['id'], return_type='dict',
|
443
|
+
# output_variables=node.get('output', []),
|
444
|
+
# input_variables=node.get('input', ['messages']),
|
445
|
+
# task=node.get('task')
|
446
|
+
# ))
|
401
447
|
elif node_type == 'loop':
|
402
448
|
lg_builder.add_node(node_id, LoopNode(
|
403
449
|
client=client, tool=tool,
|
@@ -503,9 +549,20 @@ def create_graph(
|
|
503
549
|
conditional_outputs=node['condition'].get('conditional_outputs', []),
|
504
550
|
default_output=node['condition'].get('default_output', 'END')))
|
505
551
|
|
506
|
-
|
552
|
+
# set default value for state variable at START
|
553
|
+
entry_point = clean_string(schema['entry_point'])
|
554
|
+
for key, value in state.items():
|
555
|
+
if 'type' in value and 'value' in value:
|
556
|
+
# set default value for state variable if it is defined in the schema
|
557
|
+
state_default_node = StateDefaultNode(default_vars=state)
|
558
|
+
lg_builder.add_node(state_default_node.name, state_default_node)
|
559
|
+
lg_builder.set_entry_point(state_default_node.name)
|
560
|
+
lg_builder.add_conditional_edges(state_default_node.name, TransitionalEdge(entry_point))
|
561
|
+
break
|
562
|
+
else:
|
563
|
+
# if no state variables are defined, set the entry point directly
|
564
|
+
lg_builder.set_entry_point(entry_point)
|
507
565
|
|
508
|
-
# assign default values
|
509
566
|
interrupt_before = interrupt_before or []
|
510
567
|
interrupt_after = interrupt_after or []
|
511
568
|
|
alita_sdk/toolkits/tools.py
CHANGED
@@ -52,7 +52,7 @@ def get_tools(tools_list: list, alita_client, llm) -> list:
|
|
52
52
|
selected_tools=tool['settings']['selected_tools'],
|
53
53
|
toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
|
54
54
|
).get_tools())
|
55
|
-
elif tool['type'] == 'application':
|
55
|
+
elif tool['type'] == 'application' and tool.get('agent_type', '') != 'pipeline' :
|
56
56
|
tools.extend(ApplicationToolkit.get_toolkit(
|
57
57
|
alita_client,
|
58
58
|
application_id=int(tool['settings']['application_id']),
|
@@ -60,7 +60,7 @@ def get_tools(tools_list: list, alita_client, llm) -> list:
|
|
60
60
|
app_api_key=alita_client.auth_token,
|
61
61
|
selected_tools=[]
|
62
62
|
).get_tools())
|
63
|
-
elif tool['type'] == '
|
63
|
+
elif tool['type'] == 'application' and tool.get('agent_type', '') == 'pipeline':
|
64
64
|
# static get_toolkit returns a list of CompiledStateGraph stubs
|
65
65
|
tools.extend(SubgraphToolkit.get_toolkit(
|
66
66
|
alita_client,
|
alita_sdk/tools/agent.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
import logging
|
2
|
+
from json import dumps
|
3
|
+
from traceback import format_exc
|
4
|
+
from typing import Any, Optional, Union
|
5
|
+
|
6
|
+
from langchain_core.callbacks import dispatch_custom_event
|
7
|
+
from langchain_core.messages import ToolCall
|
8
|
+
from langchain_core.runnables import RunnableConfig
|
9
|
+
from langchain_core.tools import BaseTool
|
10
|
+
from langchain_core.utils.function_calling import convert_to_openai_tool
|
11
|
+
from pydantic import ValidationError
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class AgentNode(BaseTool):
    """Tool wrapper that runs another tool as an agent step on graph state.

    The node renders a task string from selected state variables, invokes the
    wrapped tool with that task (and optionally the chat history) and returns
    a state update containing the result.
    """

    name: str = 'AgentNode'
    description: str = 'This is agent node for tools'
    # LLM client handle; not used by the code in this class.
    client: Any = None
    # The wrapped tool that receives {'task', 'chat_history'}.
    tool: BaseTool = None
    return_type: str = "str"
    # State keys to read; 'messages' enables passing the chat history.
    input_variables: Optional[list[str]] = None
    # When non-empty, the first entry receives the raw tool result.
    output_variables: Optional[list[str]] = None
    structured_output: Optional[bool] = False
    # Optional str.format template rendered with the input variables.
    task: Optional[str] = None

    def invoke(
        self,
        state: Union[str, dict, ToolCall],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the wrapped tool against the current state.

        Args:
            state: Graph state; indexed by every name in ``input_variables``
                other than ``'messages'``.
            config: Runnable configuration forwarded to the tool and to the
                custom event dispatch.
            **kwargs: Forwarded to the wrapped tool under the ``kwargs``
                keyword.

        Returns:
            A dict with a ``"messages"`` entry holding one assistant message.
            When ``output_variables`` is non-empty, the raw tool result is
            also stored under its first entry. On ``ValidationError`` the
            assistant message describes the failure instead.
        """
        params = convert_to_openai_tool(self.tool).get(
            'function', {'parameters': {}}).get(
            'parameters', {'properties': {}}).get('properties', {})
        # input_variables defaults to None; treat that as "no inputs" rather
        # than failing on iteration / membership checks below.
        input_variables = self.input_variables or []
        last_message = {}
        logger.debug(f"AgentNode input: {self.input_variables}")
        logger.debug(f"Output variables: {self.output_variables}")
        for var in input_variables:
            if var != 'messages':
                last_message[var] = state[var]
        if self.task:
            task = self.task.format(**last_message, last_message=dumps(last_message))
        else:
            task = 'Input from user: {last_message}'.format(last_message=dumps(last_message))
        # Built before the try so the error handler can always reference it.
        # The history is copied so the tool cannot mutate the state's list.
        agent_input = {
            'task': task,
            'chat_history': state.get('messages', [])[:] if 'messages' in input_variables else None,
        }
        try:
            tool_result = self.tool.invoke(agent_input, config=config, kwargs=kwargs)
            dispatch_custom_event(
                "on_tool_node", {
                    "input_variables": self.input_variables,
                    "tool_result": tool_result,
                    "state": state,
                }, config=config
            )
            message_result = tool_result
            if isinstance(tool_result, (dict, list)):
                message_result = dumps(tool_result)
            logger.info(f"AgentNode response: {tool_result}")
            if not self.output_variables:
                return {"messages": [{"role": "assistant", "content": message_result}]}
            return {self.output_variables[0]: tool_result,
                    "messages": [{"role": "assistant", "content": message_result}]}
        except ValidationError:
            logger.error(f"ValidationError: {format_exc()}")
            # The previous message indexed an always-empty list, which raised
            # IndexError from inside this handler; report the rendered task,
            # which is the input that was actually sent to the tool.
            error_text = (
                f"Tool input to the {self.tool.name} with value {agent_input} raised ValidationError.\n"
                f"\n\nTool schema is {dumps(params)} \n\nand the input to LLM was\n"
                f"{task}"
            )
            return {"messages": [{"role": "assistant", "content": error_text}]}

    def _run(self, *args, **kwargs):
        """Delegate to ``invoke`` using keyword arguments only."""
        return self.invoke(**kwargs)
|
alita_sdk/tools/tool.py
CHANGED
@@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
|
|
10
10
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
11
11
|
from pydantic import ValidationError, BaseModel, create_model
|
12
12
|
|
13
|
+
from .application import Application
|
13
14
|
from ..langchain.utils import _extract_json
|
14
15
|
|
15
16
|
logger = logging.getLogger(__name__)
|
@@ -74,8 +75,9 @@ Anwer must be JSON only extractable by JSON.LOADS."""
|
|
74
75
|
))
|
75
76
|
]
|
76
77
|
if self.structured_output:
|
77
|
-
# cut defaults from schema
|
78
|
-
fields = {name: (field.annotation, ...) for name, field
|
78
|
+
# cut defaults from schema and remove chat_history for application as a tool
|
79
|
+
fields = {name: (field.annotation, ...) for name, field
|
80
|
+
in self.tool.args_schema.model_fields.items() if name != 'chat_history'}
|
79
81
|
input_schema = create_model('NewModel', **fields)
|
80
82
|
|
81
83
|
llm = self.client.with_structured_output(input_schema)
|
@@ -87,6 +89,10 @@ Anwer must be JSON only extractable by JSON.LOADS."""
|
|
87
89
|
result = _extract_json(completion.content.strip())
|
88
90
|
logger.info(f"ToolNode tool params: {result}")
|
89
91
|
try:
|
92
|
+
# handler for application added as a tool
|
93
|
+
if isinstance(self.tool, Application):
|
94
|
+
# set empty chat history
|
95
|
+
result['chat_history'] = None
|
90
96
|
tool_result = self.tool.invoke(result, config=config, kwargs=kwargs)
|
91
97
|
dispatch_custom_event(
|
92
98
|
"on_tool_node", {
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: alita_sdk
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.128
|
4
4
|
Summary: SDK for building langchain agents using resouces from Alita
|
5
5
|
Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedjik@gmail.com>
|
6
6
|
Project-URL: Homepage, https://projectalita.ai
|
@@ -19,7 +19,7 @@ alita_sdk/langchain/assistant.py,sha256=J_xhwbNl934BgDKSpAMC9a1u6v03DZQcTYaamCzt
|
|
19
19
|
alita_sdk/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
|
20
20
|
alita_sdk/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
|
21
21
|
alita_sdk/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
|
22
|
-
alita_sdk/langchain/langraph_agent.py,sha256=
|
22
|
+
alita_sdk/langchain/langraph_agent.py,sha256=5TQQ1S2UhRY0PYCdr-W282LgPqLM2HA9xr8MOh185BY,39747
|
23
23
|
alita_sdk/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
|
24
24
|
alita_sdk/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
|
25
25
|
alita_sdk/langchain/utils.py,sha256=Npferkn10dvdksnKzLJLBI5bNGQyVWTBwqp3vQtUqmY,6631
|
@@ -69,9 +69,10 @@ alita_sdk/toolkits/artifact.py,sha256=7zb17vhJ3CigeTqvzQ4VNBsU5UOCJqAwz7fOJGMYqX
|
|
69
69
|
alita_sdk/toolkits/datasource.py,sha256=v3FQu8Gmvq7gAGAnFEbA8qofyUhh98rxgIjY6GHBfyI,2494
|
70
70
|
alita_sdk/toolkits/prompt.py,sha256=WIpTkkVYWqIqOWR_LlSWz3ug8uO9tm5jJ7aZYdiGRn0,1192
|
71
71
|
alita_sdk/toolkits/subgraph.py,sha256=ZYqI4yVLbEPAjCR8dpXbjbL2ipX598Hk3fL6AgaqFD4,1758
|
72
|
-
alita_sdk/toolkits/tools.py,sha256=
|
72
|
+
alita_sdk/toolkits/tools.py,sha256=eb4UFaSky0N42cTXHiJ9KojaReivcRNr2RM4UABQf8o,5928
|
73
73
|
alita_sdk/toolkits/vectorstore.py,sha256=di08-CRl0KJ9xSZ8_24VVnPZy58iLqHtXW8vuF29P64,2893
|
74
74
|
alita_sdk/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
75
|
+
alita_sdk/tools/agent.py,sha256=m98QxOHwnCRTT9j18Olbb5UPS8-ZGeQaGiUyZJSyFck,3162
|
75
76
|
alita_sdk/tools/application.py,sha256=UJlYd3Sub10LpAoKkKEpvd4miWyrS-yYE5NKyqx-H4Q,2194
|
76
77
|
alita_sdk/tools/artifact.py,sha256=uTa6K5d-NCDRnuLJVd6vA5TNIPH39onyPIyW5Thz4C0,6160
|
77
78
|
alita_sdk/tools/datasource.py,sha256=pvbaSfI-ThQQnjHG-QhYNSTYRnZB0rYtZFpjCfpzxYI,2443
|
@@ -85,7 +86,7 @@ alita_sdk/tools/mcp_server_tool.py,sha256=xcH9AiqfR2TYrwJ3Ixw-_A7XDodtJCnwmq1Ssi
|
|
85
86
|
alita_sdk/tools/pgvector_search.py,sha256=NN2BGAnq4SsDHIhUcFZ8d_dbEOM8QwB0UwpsWCYruXU,11692
|
86
87
|
alita_sdk/tools/prompt.py,sha256=nJafb_e5aOM1Rr3qGFCR-SKziU9uCsiP2okIMs9PppM,741
|
87
88
|
alita_sdk/tools/router.py,sha256=wCvZjVkdXK9dMMeEerrgKf5M790RudH68pDortnHSz0,1517
|
88
|
-
alita_sdk/tools/tool.py,sha256=
|
89
|
+
alita_sdk/tools/tool.py,sha256=f2ULDU4PU4PlLgygT_lsInLgNROJeWUNXLe0i0uOcqI,5419
|
89
90
|
alita_sdk/tools/vectorstore.py,sha256=F-DoHxPa4UVsKB-FEd-wWa59QGQifKMwcSNcZ5WZOKc,23496
|
90
91
|
alita_sdk/utils/AlitaCallback.py,sha256=cvpDhR4QLVCNQci6CO6TEUrUVDZU9_CRSwzcHGm3SGw,7356
|
91
92
|
alita_sdk/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -93,10 +94,10 @@ alita_sdk/utils/evaluate.py,sha256=iM1P8gzBLHTuSCe85_Ng_h30m52hFuGuhNXJ7kB1tgI,1
|
|
93
94
|
alita_sdk/utils/logging.py,sha256=hBE3qAzmcLMdamMp2YRXwOOK9P4lmNaNhM76kntVljs,3124
|
94
95
|
alita_sdk/utils/streamlit.py,sha256=zp8owZwHI3HZplhcExJf6R3-APtWx-z6s5jznT2hY_k,29124
|
95
96
|
alita_sdk/utils/utils.py,sha256=dM8whOJAuFJFe19qJ69-FLzrUp6d2G-G6L7d4ss2XqM,346
|
96
|
-
alita_sdk-0.3.
|
97
|
+
alita_sdk-0.3.128.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
97
98
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
98
99
|
tests/test_jira_analysis.py,sha256=I0cErH5R_dHVyutpXrM1QEo7jfBuKWTmDQvJBPjx18I,3281
|
99
|
-
alita_sdk-0.3.
|
100
|
-
alita_sdk-0.3.
|
101
|
-
alita_sdk-0.3.
|
102
|
-
alita_sdk-0.3.
|
100
|
+
alita_sdk-0.3.128.dist-info/METADATA,sha256=zWqqG-EN_cA9WGMHZvsejLGCpSTb1rrT2x2rdHugVoE,7075
|
101
|
+
alita_sdk-0.3.128.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
102
|
+
alita_sdk-0.3.128.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
|
103
|
+
alita_sdk-0.3.128.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|