alita-sdk 0.3.126__py3-none-any.whl → 0.3.128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ from uuid import uuid4
4
4
  from typing import Dict
5
5
 
6
6
  import yaml
7
+ import ast
7
8
  from langchain_core.callbacks import dispatch_custom_event
8
9
  from langchain_core.messages import HumanMessage
9
10
  from langchain_core.runnables import Runnable
@@ -211,6 +212,25 @@ class TransitionalEdge(Runnable):
211
212
  )
212
213
  return self.next_step if self.next_step != 'END' else END
213
214
 
215
class StateDefaultNode(Runnable):
    """Entry node that seeds the graph state with default variable values.

    Each entry in ``default_vars`` of the form ``{'value': <literal-string>}``
    is parsed with ``ast.literal_eval`` so strings such as ``"[1, 2]"`` or
    ``"{'a': 1}"`` become real Python objects; values that are not valid
    Python literals are kept verbatim.
    """
    name = "StateDefaultNode"

    def __init__(self, default_vars: Optional[dict] = None):
        # NOTE: the original used a mutable default argument (``{}``), which
        # is shared across every instance created without an argument.
        self.default_vars = default_vars if default_vars is not None else {}

    def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
        """Return a dict of default state values; ``state`` itself is not read."""
        logger.info("Setting default state variables")
        result = {}
        for key, value in self.default_vars.items():
            if isinstance(value, dict) and 'value' in value:
                temp_value = value['value']
                try:
                    result[key] = ast.literal_eval(temp_value)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    # literal_eval raises only these for malformed/oversized input;
                    # a bare ``except`` would also swallow KeyboardInterrupt etc.
                    logger.debug("Unable to evaluate value, using as is")
                    result[key] = temp_value
        return result
233
+
214
234
 
215
235
  class StateModifierNode(Runnable):
216
236
  name = "StateModifierNode"
@@ -353,7 +373,8 @@ def create_graph(
353
373
  logger.debug(f"Schema: {schema}")
354
374
  logger.debug(f"Tools: {tools}")
355
375
  logger.info(f"Tools: {[tool.name for tool in tools]}")
356
- state_class = create_state(schema.get('state', {}))
376
+ state = schema.get('state', {})
377
+ state_class = create_state(state)
357
378
  lg_builder = StateGraph(state_class)
358
379
  interrupt_before = [clean_string(every) for every in schema.get('interrupt_before', [])]
359
380
  interrupt_after = [clean_string(every) for every in schema.get('interrupt_after', [])]
@@ -366,7 +387,7 @@ def create_graph(
366
387
  if toolkit_name:
367
388
  tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
368
389
  logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
369
- if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph']:
390
+ if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
370
391
  for tool in tools:
371
392
  if tool.name == tool_name:
372
393
  if node_type == 'function':
@@ -376,14 +397,30 @@ def create_graph(
376
397
  input_mapping=node.get('input_mapping',
377
398
  {'messages': {'type': 'variable', 'value': 'messages'}}),
378
399
  input_variables=node.get('input', ['messages'])))
379
- elif node_type == 'subgraph':
400
+ elif node_type == 'agent':
401
+ input_params = node.get('input', ['messages'])
402
+ input_mapping = {'task': {'type': 'fstring', 'value': f"{node.get('task', '')}"}}
403
+ # Add 'chat_history' to input_mapping only if 'messages' is in input_params
404
+ if 'messages' in input_params:
405
+ input_mapping['chat_history'] = {'type': 'variable', 'value': 'messages'}
406
+ lg_builder.add_node(node_id, FunctionTool(
407
+ client=client, tool=tool,
408
+ name=node['id'], return_type='dict',
409
+ output_variables=node.get('output', []),
410
+ input_variables=input_params,
411
+ input_mapping= input_mapping
412
+ ))
413
+ elif node_type == 'subgraph' or node_type == 'pipeline':
380
414
  # assign parent memory/store
381
415
  # tool.checkpointer = memory
382
416
  # tool.store = store
383
417
  # wrap with mappings
418
+ pipeline_name = node.get('tool', None)
419
+ if not pipeline_name:
420
+ raise ValueError("Subgraph must have a 'tool' node: add required tool to the subgraph node")
384
421
  node_fn = SubgraphRunnable(
385
422
  inner=tool,
386
- name=node['id'],
423
+ name=pipeline_name,
387
424
  input_mapping=node.get('input_mapping', {}),
388
425
  output_mapping=node.get('output_mapping', {}),
389
426
  )
@@ -398,6 +435,15 @@ def create_graph(
398
435
  structured_output=node.get('structured_output', False),
399
436
  task=node.get('task')
400
437
  ))
438
+ # TODO: decide on struct output for agent nodes
439
+ # elif node_type == 'agent':
440
+ # lg_builder.add_node(node_id, AgentNode(
441
+ # client=client, tool=tool,
442
+ # name=node['id'], return_type='dict',
443
+ # output_variables=node.get('output', []),
444
+ # input_variables=node.get('input', ['messages']),
445
+ # task=node.get('task')
446
+ # ))
401
447
  elif node_type == 'loop':
402
448
  lg_builder.add_node(node_id, LoopNode(
403
449
  client=client, tool=tool,
@@ -503,9 +549,20 @@ def create_graph(
503
549
  conditional_outputs=node['condition'].get('conditional_outputs', []),
504
550
  default_output=node['condition'].get('default_output', 'END')))
505
551
 
506
- lg_builder.set_entry_point(clean_string(schema['entry_point']))
552
+ # set default value for state variable at START
553
+ entry_point = clean_string(schema['entry_point'])
554
+ for key, value in state.items():
555
+ if 'type' in value and 'value' in value:
556
+ # set default value for state variable if it is defined in the schema
557
+ state_default_node = StateDefaultNode(default_vars=state)
558
+ lg_builder.add_node(state_default_node.name, state_default_node)
559
+ lg_builder.set_entry_point(state_default_node.name)
560
+ lg_builder.add_conditional_edges(state_default_node.name, TransitionalEdge(entry_point))
561
+ break
562
+ else:
563
+ # if no state variables are defined, set the entry point directly
564
+ lg_builder.set_entry_point(entry_point)
507
565
 
508
- # assign default values
509
566
  interrupt_before = interrupt_before or []
510
567
  interrupt_after = interrupt_after or []
511
568
 
@@ -52,7 +52,7 @@ def get_tools(tools_list: list, alita_client, llm) -> list:
52
52
  selected_tools=tool['settings']['selected_tools'],
53
53
  toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
54
54
  ).get_tools())
55
- elif tool['type'] == 'application':
55
+ elif tool['type'] == 'application' and tool.get('agent_type', '') != 'pipeline' :
56
56
  tools.extend(ApplicationToolkit.get_toolkit(
57
57
  alita_client,
58
58
  application_id=int(tool['settings']['application_id']),
@@ -60,7 +60,7 @@ def get_tools(tools_list: list, alita_client, llm) -> list:
60
60
  app_api_key=alita_client.auth_token,
61
61
  selected_tools=[]
62
62
  ).get_tools())
63
- elif tool['type'] == 'subgraph':
63
+ elif tool['type'] == 'application' and tool.get('agent_type', '') == 'pipeline':
64
64
  # static get_toolkit returns a list of CompiledStateGraph stubs
65
65
  tools.extend(SubgraphToolkit.get_toolkit(
66
66
  alita_client,
@@ -0,0 +1,74 @@
1
+ import logging
2
+ from json import dumps
3
+ from traceback import format_exc
4
+ from typing import Any, Optional, Union
5
+
6
+ from langchain_core.callbacks import dispatch_custom_event
7
+ from langchain_core.messages import ToolCall
8
+ from langchain_core.runnables import RunnableConfig
9
+ from langchain_core.tools import BaseTool
10
+ from langchain_core.utils.function_calling import convert_to_openai_tool
11
+ from pydantic import ValidationError
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class AgentNode(BaseTool):
    """LangGraph node that delegates the current graph state to a wrapped agent.

    Builds a ``task`` string from the state variables listed in
    ``input_variables`` (optionally formatted through the ``task`` template),
    invokes the wrapped agent tool, and folds the result back into the graph
    state as an assistant message (plus the first ``output_variables`` entry,
    if any).
    """
    name: str = 'AgentNode'
    description: str = 'This is agent node for tools'
    client: Any = None          # Alita client; kept for parity with sibling nodes
    tool: BaseTool = None       # the agent being wrapped
    return_type: str = "str"
    input_variables: Optional[list[str]] = None   # state keys fed into the task
    output_variables: Optional[list[str]] = None  # first entry receives the raw result
    structured_output: Optional[bool] = False
    task: Optional[str] = None  # optional format template for the task text

    def invoke(
        self,
        state: Union[str, dict, ToolCall],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        # Tool parameter schema -- only used in the ValidationError message below.
        params = convert_to_openai_tool(self.tool).get(
            'function', {'parameters': {}}).get(
            'parameters', {'properties': {}}).get('properties', {})
        last_message = {}
        logger.debug(f"AgentNode input: {self.input_variables}")
        logger.debug(f"Output variables: {self.output_variables}")
        for var in self.input_variables:
            if var != 'messages':
                last_message[var] = state[var]
        if self.task:
            task = self.task.format(**last_message, last_message=dumps(last_message))
        else:
            task = 'Input from user: {last_message}'.format(last_message=dumps(last_message))
        # Built outside the try so it is always bound when the handler formats it.
        agent_input = {
            'task': task,
            # Pass a copy of the history so the sub-agent cannot mutate state.
            'chat_history': state.get('messages', [])[:] if 'messages' in self.input_variables else None,
        }
        try:
            tool_result = self.tool.invoke(agent_input, config=config, kwargs=kwargs)
            dispatch_custom_event(
                "on_tool_node", {
                    "input_variables": self.input_variables,
                    "tool_result": tool_result,
                    "state": state,
                }, config=config
            )
            # Chat messages must carry plain text; serialize structured results.
            message_result = tool_result
            if isinstance(tool_result, (dict, list)):
                message_result = dumps(tool_result)
            logger.info(f"AgentNode response: {tool_result}")
            if not self.output_variables:
                return {"messages": [{"role": "assistant", "content": message_result}]}
            return {self.output_variables[0]: tool_result,
                    "messages": [{"role": "assistant", "content": message_result}]}
        except ValidationError:
            # BUGFIX: the original interpolated ``input_[-1].content`` where
            # ``input_`` was always an empty list, so the handler itself raised
            # IndexError and masked the real ValidationError.
            logger.error(f"ValidationError: {format_exc()}")
            return {
                "messages": [{
                    "role": "assistant",
                    "content": f"Tool input to the {self.tool.name} with value "
                               f"{agent_input} raised ValidationError."
                               f"\n\nTool schema is {dumps(params)}"}]}

    def _run(self, *args, **kwargs):
        # BaseTool entry point; all real work happens in invoke().
        return self.invoke(**kwargs)
alita_sdk/tools/tool.py CHANGED
@@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
10
10
  from langchain_core.utils.function_calling import convert_to_openai_tool
11
11
  from pydantic import ValidationError, BaseModel, create_model
12
12
 
13
+ from .application import Application
13
14
  from ..langchain.utils import _extract_json
14
15
 
15
16
  logger = logging.getLogger(__name__)
@@ -74,8 +75,9 @@ Anwer must be JSON only extractable by JSON.LOADS."""
74
75
  ))
75
76
  ]
76
77
  if self.structured_output:
77
- # cut defaults from schema
78
- fields = {name: (field.annotation, ...) for name, field in self.tool.args_schema.model_fields.items()}
78
+ # cut defaults from schema and remove chat_history for application as a tool
79
+ fields = {name: (field.annotation, ...) for name, field
80
+ in self.tool.args_schema.model_fields.items() if name != 'chat_history'}
79
81
  input_schema = create_model('NewModel', **fields)
80
82
 
81
83
  llm = self.client.with_structured_output(input_schema)
@@ -87,6 +89,10 @@ Anwer must be JSON only extractable by JSON.LOADS."""
87
89
  result = _extract_json(completion.content.strip())
88
90
  logger.info(f"ToolNode tool params: {result}")
89
91
  try:
92
+ # handler for application added as a tool
93
+ if isinstance(self.tool, Application):
94
+ # set empty chat history
95
+ result['chat_history'] = None
90
96
  tool_result = self.tool.invoke(result, config=config, kwargs=kwargs)
91
97
  dispatch_custom_event(
92
98
  "on_tool_node", {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alita_sdk
3
- Version: 0.3.126
3
+ Version: 0.3.128
4
4
  Summary: SDK for building langchain agents using resouces from Alita
5
5
  Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedjik@gmail.com>
6
6
  Project-URL: Homepage, https://projectalita.ai
@@ -19,7 +19,7 @@ alita_sdk/langchain/assistant.py,sha256=J_xhwbNl934BgDKSpAMC9a1u6v03DZQcTYaamCzt
19
19
  alita_sdk/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
20
20
  alita_sdk/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
21
21
  alita_sdk/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
22
- alita_sdk/langchain/langraph_agent.py,sha256=HwopuxCWDOg6i-ZKbxZzrqnRZ84pGIS7kVN349ER8bs,36510
22
+ alita_sdk/langchain/langraph_agent.py,sha256=5TQQ1S2UhRY0PYCdr-W282LgPqLM2HA9xr8MOh185BY,39747
23
23
  alita_sdk/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
24
24
  alita_sdk/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
25
25
  alita_sdk/langchain/utils.py,sha256=Npferkn10dvdksnKzLJLBI5bNGQyVWTBwqp3vQtUqmY,6631
@@ -69,9 +69,10 @@ alita_sdk/toolkits/artifact.py,sha256=7zb17vhJ3CigeTqvzQ4VNBsU5UOCJqAwz7fOJGMYqX
69
69
  alita_sdk/toolkits/datasource.py,sha256=v3FQu8Gmvq7gAGAnFEbA8qofyUhh98rxgIjY6GHBfyI,2494
70
70
  alita_sdk/toolkits/prompt.py,sha256=WIpTkkVYWqIqOWR_LlSWz3ug8uO9tm5jJ7aZYdiGRn0,1192
71
71
  alita_sdk/toolkits/subgraph.py,sha256=ZYqI4yVLbEPAjCR8dpXbjbL2ipX598Hk3fL6AgaqFD4,1758
72
- alita_sdk/toolkits/tools.py,sha256=gk3nvQBdab3QM8v93ff2nrN4ZfcT779yae2RygkTl8s,5834
72
+ alita_sdk/toolkits/tools.py,sha256=eb4UFaSky0N42cTXHiJ9KojaReivcRNr2RM4UABQf8o,5928
73
73
  alita_sdk/toolkits/vectorstore.py,sha256=di08-CRl0KJ9xSZ8_24VVnPZy58iLqHtXW8vuF29P64,2893
74
74
  alita_sdk/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
75
+ alita_sdk/tools/agent.py,sha256=m98QxOHwnCRTT9j18Olbb5UPS8-ZGeQaGiUyZJSyFck,3162
75
76
  alita_sdk/tools/application.py,sha256=UJlYd3Sub10LpAoKkKEpvd4miWyrS-yYE5NKyqx-H4Q,2194
76
77
  alita_sdk/tools/artifact.py,sha256=uTa6K5d-NCDRnuLJVd6vA5TNIPH39onyPIyW5Thz4C0,6160
77
78
  alita_sdk/tools/datasource.py,sha256=pvbaSfI-ThQQnjHG-QhYNSTYRnZB0rYtZFpjCfpzxYI,2443
@@ -85,7 +86,7 @@ alita_sdk/tools/mcp_server_tool.py,sha256=xcH9AiqfR2TYrwJ3Ixw-_A7XDodtJCnwmq1Ssi
85
86
  alita_sdk/tools/pgvector_search.py,sha256=NN2BGAnq4SsDHIhUcFZ8d_dbEOM8QwB0UwpsWCYruXU,11692
86
87
  alita_sdk/tools/prompt.py,sha256=nJafb_e5aOM1Rr3qGFCR-SKziU9uCsiP2okIMs9PppM,741
87
88
  alita_sdk/tools/router.py,sha256=wCvZjVkdXK9dMMeEerrgKf5M790RudH68pDortnHSz0,1517
88
- alita_sdk/tools/tool.py,sha256=jFRq8BeC55NwpgdpsqGk_Y3tZL4YKN0rE7sVS5OE3yg,5092
89
+ alita_sdk/tools/tool.py,sha256=f2ULDU4PU4PlLgygT_lsInLgNROJeWUNXLe0i0uOcqI,5419
89
90
  alita_sdk/tools/vectorstore.py,sha256=F-DoHxPa4UVsKB-FEd-wWa59QGQifKMwcSNcZ5WZOKc,23496
90
91
  alita_sdk/utils/AlitaCallback.py,sha256=cvpDhR4QLVCNQci6CO6TEUrUVDZU9_CRSwzcHGm3SGw,7356
91
92
  alita_sdk/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -93,10 +94,10 @@ alita_sdk/utils/evaluate.py,sha256=iM1P8gzBLHTuSCe85_Ng_h30m52hFuGuhNXJ7kB1tgI,1
93
94
  alita_sdk/utils/logging.py,sha256=hBE3qAzmcLMdamMp2YRXwOOK9P4lmNaNhM76kntVljs,3124
94
95
  alita_sdk/utils/streamlit.py,sha256=zp8owZwHI3HZplhcExJf6R3-APtWx-z6s5jznT2hY_k,29124
95
96
  alita_sdk/utils/utils.py,sha256=dM8whOJAuFJFe19qJ69-FLzrUp6d2G-G6L7d4ss2XqM,346
96
- alita_sdk-0.3.126.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
97
+ alita_sdk-0.3.128.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
97
98
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
98
99
  tests/test_jira_analysis.py,sha256=I0cErH5R_dHVyutpXrM1QEo7jfBuKWTmDQvJBPjx18I,3281
99
- alita_sdk-0.3.126.dist-info/METADATA,sha256=eX2afqBm4mw5_LMMdJ_HXMXHTp93O8bWSIQ3jCVy8go,7075
100
- alita_sdk-0.3.126.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
101
- alita_sdk-0.3.126.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
102
- alita_sdk-0.3.126.dist-info/RECORD,,
100
+ alita_sdk-0.3.128.dist-info/METADATA,sha256=zWqqG-EN_cA9WGMHZvsejLGCpSTb1rrT2x2rdHugVoE,7075
101
+ alita_sdk-0.3.128.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
102
+ alita_sdk-0.3.128.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
103
+ alita_sdk-0.3.128.dist-info/RECORD,,