bisheng-langchain 0.2.3.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (45)
  1. bisheng_langchain/agents/llm_functions_agent/base.py +1 -1
  2. bisheng_langchain/chains/__init__.py +2 -1
  3. bisheng_langchain/chains/transform.py +85 -0
  4. bisheng_langchain/chat_models/host_llm.py +19 -5
  5. bisheng_langchain/chat_models/qwen.py +29 -8
  6. bisheng_langchain/document_loaders/custom_kv.py +1 -1
  7. bisheng_langchain/embeddings/host_embedding.py +9 -11
  8. bisheng_langchain/gpts/__init__.py +0 -0
  9. bisheng_langchain/gpts/agent_types/__init__.py +10 -0
  10. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +220 -0
  11. bisheng_langchain/gpts/assistant.py +137 -0
  12. bisheng_langchain/gpts/auto_optimization.py +130 -0
  13. bisheng_langchain/gpts/auto_tool_selected.py +54 -0
  14. bisheng_langchain/gpts/load_tools.py +161 -0
  15. bisheng_langchain/gpts/message_types.py +11 -0
  16. bisheng_langchain/gpts/prompts/__init__.py +15 -0
  17. bisheng_langchain/gpts/prompts/assistant_prompt_opt.py +95 -0
  18. bisheng_langchain/gpts/prompts/base_prompt.py +1 -0
  19. bisheng_langchain/gpts/prompts/breif_description_prompt.py +104 -0
  20. bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +118 -0
  21. bisheng_langchain/gpts/prompts/select_tools_prompt.py +29 -0
  22. bisheng_langchain/gpts/tools/__init__.py +0 -0
  23. bisheng_langchain/gpts/tools/api_tools/__init__.py +50 -0
  24. bisheng_langchain/gpts/tools/api_tools/base.py +90 -0
  25. bisheng_langchain/gpts/tools/api_tools/flow.py +59 -0
  26. bisheng_langchain/gpts/tools/api_tools/macro_data.py +397 -0
  27. bisheng_langchain/gpts/tools/api_tools/sina.py +221 -0
  28. bisheng_langchain/gpts/tools/api_tools/tianyancha.py +160 -0
  29. bisheng_langchain/gpts/tools/bing_search/__init__.py +0 -0
  30. bisheng_langchain/gpts/tools/bing_search/tool.py +55 -0
  31. bisheng_langchain/gpts/tools/calculator/__init__.py +0 -0
  32. bisheng_langchain/gpts/tools/calculator/tool.py +25 -0
  33. bisheng_langchain/gpts/tools/code_interpreter/__init__.py +0 -0
  34. bisheng_langchain/gpts/tools/code_interpreter/tool.py +261 -0
  35. bisheng_langchain/gpts/tools/dalle_image_generator/__init__.py +0 -0
  36. bisheng_langchain/gpts/tools/dalle_image_generator/tool.py +181 -0
  37. bisheng_langchain/gpts/tools/get_current_time/__init__.py +0 -0
  38. bisheng_langchain/gpts/tools/get_current_time/tool.py +23 -0
  39. bisheng_langchain/gpts/utils.py +197 -0
  40. bisheng_langchain/utils/requests.py +5 -1
  41. bisheng_langchain/vectorstores/milvus.py +1 -1
  42. {bisheng_langchain-0.2.3.1.dist-info → bisheng_langchain-0.3.0.dist-info}/METADATA +5 -2
  43. {bisheng_langchain-0.2.3.1.dist-info → bisheng_langchain-0.3.0.dist-info}/RECORD +45 -12
  44. {bisheng_langchain-0.2.3.1.dist-info → bisheng_langchain-0.3.0.dist-info}/WHEEL +0 -0
  45. {bisheng_langchain-0.2.3.1.dist-info → bisheng_langchain-0.3.0.dist-info}/top_level.txt +0 -0

bisheng_langchain/agents/llm_functions_agent/base.py
@@ -8,7 +8,6 @@ from bisheng_langchain.chat_models.proxy_llm import ProxyChatLLM
 from langchain.agents import BaseSingleActionAgent
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.prompts.chat import (BaseMessagePromptTemplate, ChatPromptTemplate,
                                     HumanMessagePromptTemplate, MessagesPlaceholder)
 from langchain.schema import AgentAction, AgentFinish, BasePromptTemplate, OutputParserException
@@ -18,6 +17,7 @@ from langchain.tools import BaseTool
 from langchain.tools.convert_to_openai import format_tool_to_openai_function
 from langchain_core.agents import AgentActionMessageLog
 from langchain_core.pydantic_v1 import root_validator
+from langchain_openai import ChatOpenAI
 
 
 def _convert_agent_action_to_messages(agent_action: AgentAction,

bisheng_langchain/chains/__init__.py
@@ -4,10 +4,11 @@ from bisheng_langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
 from bisheng_langchain.chains.retrieval.retrieval_chain import RetrievalChain
 from bisheng_langchain.chains.router.multi_rule import MultiRuleChain
 from bisheng_langchain.chains.router.rule_router import RuleBasedRouter
+from bisheng_langchain.chains.transform import TransformChain
 
 from .loader_output import LoaderOutputChain
 
 __all__ = [
     'StuffDocumentsChain', 'LoaderOutputChain', 'AutoGenChain', 'RuleBasedRouter',
-    'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain'
+    'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain', 'TransformChain'
 ]

bisheng_langchain/chains/transform.py
@@ -0,0 +1,85 @@
+"""Chain that runs an arbitrary python function."""
+import functools
+import inspect
+import logging
+from typing import Any, Awaitable, Callable, Dict, List, Optional
+
+from langchain.chains.base import Chain
+from langchain_core.callbacks import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
+from langchain_core.pydantic_v1 import Field
+
+logger = logging.getLogger(__name__)
+
+
+class TransformChain(Chain):
+    """Chain that transforms the chain output.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import TransformChain
+            transform_chain = TransformChain(input_variables=["text"],
+                output_variables=["entities"], transform=func)
+    """
+
+    input_variables: List[str]
+    """The keys expected by the transform's input dictionary."""
+    output_variables: List[str]
+    """The keys returned by the transform's output dictionary."""
+    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias='transform')
+    """The transform function."""
+    atransform_cb: Optional[Callable[[Dict[str, Any]],
+                                     Awaitable[Dict[str, Any]]]] = Field(None, alias='atransform')
+    """The async coroutine transform function."""
+
+    @staticmethod
+    @functools.lru_cache
+    def _log_once(msg: str) -> None:
+        """Log a message once.
+
+        :meta private:
+        """
+        logger.warning(msg)
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Expect input keys.
+
+        :meta private:
+        """
+        return self.input_variables
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Return output keys.
+
+        :meta private:
+        """
+        return self.output_variables
+
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        new_arg_supported = inspect.signature(self.transform_cb).parameters.get('run_manager')
+        if new_arg_supported:
+            return self.transform_cb(inputs, run_manager)
+        else:
+            return self.transform_cb(inputs)
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        if self.atransform_cb is not None:
+            new_arg_supported = inspect.signature(self.atransform_cb).parameters.get('run_manager')
+            if new_arg_supported:
+                return await self.atransform_cb(inputs, run_manager)
+            else:
+                return await self.atransform_cb(inputs)
+        else:
+            self._log_once("TransformChain's atransform is not provided, falling"
+                           ' back to synchronous transform')
+            return self._call(inputs, run_manager)
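
The new TransformChain wraps a plain Python callable as a chain. A minimal usage sketch, assuming the 0.3.0 import path; extract_entities and its dictionary keys are illustrative placeholders, not part of the package:

# Hypothetical usage of the TransformChain added in this release.
from bisheng_langchain.chains import TransformChain

def extract_entities(inputs: dict) -> dict:
    # Toy transform: treat whitespace-separated tokens as "entities".
    return {'entities': inputs['text'].split()}

transform_chain = TransformChain(input_variables=['text'],
                                 output_variables=['entities'],
                                 transform=extract_entities)
result = transform_chain.invoke({'text': 'bisheng langchain 0.3.0'})
print(result['entities'])  # ['bisheng', 'langchain', '0.3.0']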

bisheng_langchain/chat_models/host_llm.py
@@ -200,6 +200,7 @@ class BaseHostChatLLM(BaseChatModel):
         max_tokens = kwargs.get('max_tokens')
         do_sample = kwargs.get('do_sample')
         params = {
+            'stream': False,
             'messages': messages,
             'model': self.model_name,
             'top_p': top_p,
@@ -285,6 +286,7 @@ class BaseHostChatLLM(BaseChatModel):
         except Exception as e:
             raise ValueError(f'exception in host llm infer: [{e}]') from e
 
+        text_haf = ''
         async for response in _acompletion_with_retry(**kwargs):
             is_error = False
             if response:
@@ -292,10 +294,19 @@
                     is_error = True
                 elif response.startswith('data:'):
                     text = response[len('data:'):].strip()
-                    if text.startswith('{'):
-                        yield (is_error, response[len('data:'):])
-                    else:
-                        logger.info('agenerate_no_json text=%s', text)
+                    if text == '[DONE]':
+                        break
+                    try:
+                        json.loads(text_haf + text)
+                        yield (is_error, text_haf + text)
+                        text_haf = ''
+                    except Exception:
+                        # the packet was split across events
+                        if text.startswith('{'):
+                            text_haf = text
+                            continue
+                        logger.error(f'response_not_json response={response}')
+
                 if is_error:
                     break
                 elif response.startswith('{'):
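
The streaming loop above now buffers 'data:' payloads that arrive split across two SSE events and yields only once the buffered text parses as JSON. The same reassembly idea in isolation, with hypothetical fragments:

# Standalone sketch of the packet-reassembly logic (fragments are made up).
import json

fragments = ['{"choices": [{"delta": ', '{"content": "hi"}}]}']
text_haf = ''  # holds the first half of a split packet
for text in fragments:
    try:
        json.loads(text_haf + text)
        print('complete event:', text_haf + text)
        text_haf = ''
    except ValueError:
        if text.startswith('{'):
            text_haf = text  # incomplete JSON: wait for the next fragment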
@@ -521,6 +532,7 @@ class HostQwen1_5Chat(BaseHostChatLLM):
         """Return type of chat model."""
         return 'qwen1.5_chat'
 
+
 class HostLlama2Chat(BaseHostChatLLM):
     # Llama-2-7b-chat-hf, Llama-2-13b-chat-hf, Llama-2-70b-chat-hf
     model_name: str = Field('Llama-2-7b-chat-hf', alias='model')
@@ -549,6 +561,7 @@ class CustomLLMChat(BaseHostChatLLM):
         """Return type of chat model."""
         return 'custom_llm_chat'
 
+
 class HostYuanChat(BaseHostChatLLM):
     # use custom llm chat api; the api should be compatible with the openai definition
     model_name: str = Field('Yuan2-2B-Janus-hf', alias='model')
@@ -562,7 +575,8 @@ class HostYuanChat(BaseHostChatLLM):
     def _llm_type(self) -> str:
         """Return type of chat model."""
         return 'yuan2'
-
+
+
 class HostYiChat(BaseHostChatLLM):
     # use custom llm chat api; the api should be compatible with the openai definition
     model_name: str = Field('Yi-34B-Chat', alias='model')

bisheng_langchain/chat_models/qwen.py
@@ -13,7 +13,7 @@ from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (AIMessage, BaseMessage, ChatMessage, FunctionMessage,
-                                       HumanMessage, SystemMessage)
+                                       HumanMessage, SystemMessage, ToolMessage)
 from langchain.utils import get_from_dict_or_env
 from langchain_core.pydantic_v1 import Field, root_validator
 from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
@@ -60,11 +60,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             additional_kwargs = {'function_call': dict(_dict['function_call'])}
         else:
             additional_kwargs = {}
+        if _dict.get("tool_calls"):
+            additional_kwargs = {'tool_calls': _dict['tool_calls']}
+        else:
+            additional_kwargs = {}
         return AIMessage(content=content, additional_kwargs=additional_kwargs)
     elif role == 'system':
         return SystemMessage(content=_dict['content'])
     elif role == 'function':
         return FunctionMessage(content=_dict['content'], name=_dict['name'])
+    elif role == "tool":
+        additional_kwargs = {}
+        if "name" in _dict:
+            additional_kwargs["name"] = _dict["name"]
+        return ToolMessage(
+            content=_dict.get("content", ""),
+            tool_call_id=_dict.get("tool_call_id"),
+            additional_kwargs=additional_kwargs,
+        )
     else:
         return ChatMessage(content=_dict['content'], role=role)
 
@@ -78,6 +91,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict = {'role': 'assistant', 'content': message.content}
         if 'function_call' in message.additional_kwargs:
             message_dict['function_call'] = message.additional_kwargs['function_call']
+        if "tool_calls" in message.additional_kwargs:
+            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
     elif isinstance(message, SystemMessage):
         message_dict = {'role': 'system', 'content': message.content}
     elif isinstance(message, FunctionMessage):
@@ -86,6 +101,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             'content': message.content,
             'name': message.name,
         }
+    elif isinstance(message, ToolMessage):
+        message_dict = {
+            "role": "tool",
+            "content": message.content,
+            "tool_call_id": message.tool_call_id,
+        }
     else:
         raise ValueError(f'Got unknown type {message}')
     if 'name' in message.additional_kwargs:
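
With both converters extended, a 'tool' role dict should round-trip. A quick sketch using the module's private helpers (values are hypothetical):

# Hypothetical round trip through qwen.py's message converters.
from bisheng_langchain.chat_models.qwen import _convert_dict_to_message, _convert_message_to_dict
from langchain.schema.messages import ToolMessage

tool_dict = {'role': 'tool', 'content': '{"temp_c": 21}', 'tool_call_id': 'call_0'}
msg = _convert_dict_to_message(tool_dict)
assert isinstance(msg, ToolMessage) and msg.tool_call_id == 'call_0'
assert _convert_message_to_dict(msg)['role'] == 'tool'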
@@ -281,7 +302,7 @@ class ChatQWen(BaseChatModel):
             inner_completion = ''
             role = 'assistant'
             params['stream'] = True
-            function_call: Optional[dict] = None
+            tool_calls: Optional[list[dict]] = None
            async for is_error, stream_resp in self.acompletion_with_retry(messages=message_dicts,
                                                                           **params):
                 output = None
@@ -297,18 +318,18 @@ class ChatQWen(BaseChatModel):
                     role = choice['message'].get('role', role)
                     token = choice['message'].get('content', '')
                     inner_completion += token or ''
-                    _function_call = choice['message'].get('function_call')
+                    _tool_calls = choice['message'].get('tool_calls')
                     if run_manager:
                         await run_manager.on_llm_new_token(token)
-                    if _function_call:
-                        if function_call is None:
-                            function_call = _function_call
+                    if _tool_calls:
+                        if tool_calls is None:
+                            tool_calls = _tool_calls
                         else:
-                            function_call['arguments'] += _function_call['arguments']
+                            tool_calls[0]['arguments'] += _tool_calls[0]['arguments']
             message = _convert_dict_to_message({
                 'content': inner_completion,
                 'role': role,
-                'function_call': function_call,
+                'tool_calls': tool_calls,
             })
             return ChatResult(generations=[ChatGeneration(message=message)])
         else:
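
The streaming branch now assumes a single tool call whose JSON argument string arrives in pieces, concatenating the 'arguments' fragments as deltas come in. The accumulation pattern in isolation (deltas are made up):

# Sketch of how streamed tool_call argument fragments are merged.
deltas = [
    [{'id': 'call_0', 'function': {'name': 'weather'}, 'arguments': '{"city": '}],
    [{'id': 'call_0', 'function': {'name': 'weather'}, 'arguments': '"Paris"}'}],
]
tool_calls = None
for _tool_calls in deltas:
    if tool_calls is None:
        tool_calls = _tool_calls
    else:
        tool_calls[0]['arguments'] += _tool_calls[0]['arguments']
print(tool_calls[0]['arguments'])  # {"city": "Paris"}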

bisheng_langchain/document_loaders/custom_kv.py
@@ -164,7 +164,7 @@ class CustomKVLoader(BaseLoader):
                 raise Exception('custom_kv parse_error')
         else:
             logger.error(f'custom_kv=create_task resp={resp.text}')
-            raise Exception('custom_kv create task file')
+            raise Exception('custom_kv create task fail')
         content = json.dumps(document_result)
         doc = Document(page_content=content)
         return [doc]

bisheng_langchain/embeddings/host_embedding.py
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import requests
 from langchain.embeddings.base import Embeddings
 from langchain.utils import get_from_dict_or_env
-from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
                       wait_exponential)
 
@@ -64,11 +64,6 @@ class HostEmbeddings(BaseModel, Embeddings):
 
     url_ep: Optional[str] = None
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the api key and python package exist in the environment."""
@@ -108,12 +103,14 @@
         len_text = len(texts)
         while start_index < len_text:
             inp_local = {
-                'texts':texts[start_index:min(start_index + max_text_to_split, len_text)],
-                'model':self.model,
-                'type':emb_type
-            }
+                'texts': texts[start_index:min(start_index + max_text_to_split, len_text)],
+                'model': self.model,
+                'type': emb_type
+            }
             try:
-                outp_single = self.client(url=self.url_ep, json=inp_local, timeout=self.request_timeout).json()
+                outp_single = self.client(url=self.url_ep,
+                                          json=inp_local,
+                                          timeout=self.request_timeout).json()
                 if outp is None:
                     outp = outp_single
                 else:
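
The loop above sends the texts to the hosted endpoint in slices of at most max_text_to_split items and merges the responses. The chunking itself, sketched with assumed values:

# Sketch of the request batching; the payload keys mirror the diff above,
# while the batch size, model, and type values are assumptions.
texts = [f'doc {i}' for i in range(7)]
max_text_to_split = 3  # assumed value for illustration
start_index = 0
while start_index < len(texts):
    batch = texts[start_index:min(start_index + max_text_to_split, len(texts))]
    print({'texts': batch, 'model': 'custom-embedding', 'type': 'doc'})
    start_index += max_text_to_split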
@@ -162,6 +159,7 @@ class JINAEmbedding(HostEmbeddings):
     model: str = 'jina'
     embedding_ctx_length: int = 512
 
+
 class CustomHostEmbedding(HostEmbeddings):
     model: str = Field('custom-embedding', alias='model')
     embedding_ctx_length: int = 512

bisheng_langchain/gpts/__init__.py
File without changes (new empty file)

bisheng_langchain/gpts/agent_types/__init__.py
@@ -0,0 +1,10 @@
+from bisheng_langchain.gpts.agent_types.llm_functions_agent import (
+    get_openai_functions_agent_executor,
+    get_qwen_local_functions_agent_executor
+)
+
+
+__all__ = [
+    "get_openai_functions_agent_executor",
+    "get_qwen_local_functions_agent_executor"
+]

bisheng_langchain/gpts/agent_types/llm_functions_agent.py
@@ -0,0 +1,220 @@
+import json
+
+from bisheng_langchain.gpts.message_types import LiberalFunctionMessage, LiberalToolMessage
+from langchain.tools import BaseTool
+from langchain.tools.render import format_tool_to_openai_tool
+from langchain_core.language_models.base import LanguageModelLike
+from langchain_core.messages import FunctionMessage, SystemMessage, ToolMessage
+from langgraph.graph import END
+from langgraph.graph.message import MessageGraph
+from langgraph.prebuilt import ToolExecutor, ToolInvocation
+
+
+def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageModelLike,
+                                        system_message: str, interrupt_before_action: bool,
+                                        **kwargs):
+
+    async def _get_messages(messages):
+        msgs = []
+        for m in messages:
+            if isinstance(m, LiberalToolMessage):
+                _dict = m.dict()
+                _dict['content'] = str(_dict['content'])
+                m_c = ToolMessage(**_dict)
+                msgs.append(m_c)
+            else:
+                msgs.append(m)
+
+        return [SystemMessage(content=system_message)] + msgs
+
+    if tools:
+        llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
+    else:
+        llm_with_tools = llm
+    agent = _get_messages | llm_with_tools
+    tool_executor = ToolExecutor(tools)
+
+    # Define the function that determines whether to continue or not
+    def should_continue(messages):
+        last_message = messages[-1]
+        # If there is no tool call, then we finish
+        if 'tool_calls' not in last_message.additional_kwargs:
+            return 'end'
+        # Otherwise if there is, we continue
+        else:
+            return 'continue'
+
+    # Define the function to execute tools
+    async def call_tool(messages):
+        actions: list[ToolInvocation] = []
+        # Based on the continue condition
+        # we know the last message involves a function call
+        last_message = messages[-1]
+        for tool_call in last_message.additional_kwargs['tool_calls']:
+            function = tool_call['function']
+            function_name = function['name']
+            _tool_input = json.loads(function['arguments'] or '{}')
+            # We construct a ToolInvocation from the function_call
+            actions.append(ToolInvocation(
+                tool=function_name,
+                tool_input=_tool_input,
+            ))
+        # We call the tool_executor and get back a response
+        responses = await tool_executor.abatch(actions, **kwargs)
+        # We use the response to create a ToolMessage
+        tool_messages = [
+            LiberalToolMessage(
+                tool_call_id=tool_call['id'],
+                content=response,
+                additional_kwargs={'name': tool_call['function']['name']},
+            )
+            for tool_call, response in zip(last_message.additional_kwargs['tool_calls'], responses)
+        ]
+        return tool_messages
+
+    workflow = MessageGraph()
+
+    # Define the two nodes we will cycle between
+    workflow.add_node('agent', agent)
+    workflow.add_node('action', call_tool)
+
+    # Set the entrypoint as `agent`
+    # This means that this node is the first one called
+    workflow.set_entry_point('agent')
+
+    # We now add a conditional edge
+    workflow.add_conditional_edges(
+        # First, we define the start node. We use `agent`.
+        # This means these are the edges taken after the `agent` node is called.
+        'agent',
+        # Next, we pass in the function that will determine which node is called next.
+        should_continue,
+        # Finally we pass in a mapping.
+        # The keys are strings, and the values are other nodes.
+        # END is a special node marking that the graph should finish.
+        # What will happen is we will call `should_continue`, and then the output of that
+        # will be matched against the keys in this mapping.
+        # Based on which one it matches, that node will then be called.
+        {
+            # If `tools`, then we call the tool node.
+            'continue': 'action',
+            # Otherwise we finish.
+            'end': END,
+        },
+    )
+
+    # We now add a normal edge from `tools` to `agent`.
+    # This means that after `tools` is called, `agent` node is called next.
+    workflow.add_edge('action', 'agent')
+
+    # Finally, we compile it!
+    # This compiles it into a LangChain Runnable,
+    # meaning you can use it as you would any other runnable
+    app = workflow.compile()
+    if interrupt_before_action:
+        app.interrupt = ['action:inbox']
+    return app
+
+
+def get_qwen_local_functions_agent_executor(
+    tools: list[BaseTool],
+    llm: LanguageModelLike,
+    system_message: str,
+    interrupt_before_action: bool,
+    **kwargs,
+):
+
+    async def _get_messages(messages):
+        msgs = []
+        for m in messages:
+            if isinstance(m, LiberalFunctionMessage):
+                _dict = m.dict()
+                _dict['content'] = str(_dict['content'])
+                m_c = FunctionMessage(**_dict)
+                msgs.append(m_c)
+            else:
+                msgs.append(m)
+
+        return [SystemMessage(content=system_message)] + msgs
+
+    if tools:
+        llm_with_tools = llm.bind(
+            functions=[format_tool_to_openai_tool(t)['function'] for t in tools])
+    else:
+        llm_with_tools = llm
+    agent = _get_messages | llm_with_tools
+    tool_executor = ToolExecutor(tools)
+
+    # Define the function that determines whether to continue or not
+    def should_continue(messages):
+        last_message = messages[-1]
+        # If there is no function call, then we finish
+        if 'function_call' not in last_message.additional_kwargs:
+            return 'end'
+        # Otherwise if there is, we continue
+        else:
+            return 'continue'
+
+    # Define the function to execute tools
+    async def call_tool(messages):
+        actions: list[ToolInvocation] = []
+        # Based on the continue condition
+        # we know the last message involves a function call
+        last_message = messages[-1]
+        # only one function
+        function = last_message.additional_kwargs['function_call']
+        function_name = function['name']
+        _tool_input = json.loads(function['arguments'] or '{}')
+        # We construct a ToolInvocation from the function_call
+        actions.append(ToolInvocation(
+            tool=function_name,
+            tool_input=_tool_input,
+        ))
+        # We call the tool_executor and get back a response
+        responses = await tool_executor.abatch(actions, **kwargs)
+        # We use the response to create a FunctionMessage
+        tool_messages = [LiberalFunctionMessage(content=responses[0], name=function_name)]
+        return tool_messages
+
+    workflow = MessageGraph()
+
+    # Define the two nodes we will cycle between
+    workflow.add_node('agent', agent)
+    workflow.add_node('action', call_tool)
+
+    # Set the entrypoint as `agent`
+    # This means that this node is the first one called
+    workflow.set_entry_point('agent')
+
+    # We now add a conditional edge
+    workflow.add_conditional_edges(
+        # First, we define the start node. We use `agent`.
+        # This means these are the edges taken after the `agent` node is called.
+        'agent',
+        # Next, we pass in the function that will determine which node is called next.
+        should_continue,
+        # Finally we pass in a mapping.
+        # The keys are strings, and the values are other nodes.
+        # END is a special node marking that the graph should finish.
+        # What will happen is we will call `should_continue`, and then the output of that
+        # will be matched against the keys in this mapping.
+        # Based on which one it matches, that node will then be called.
+        {
+            # If `tools`, then we call the tool node.
+            'continue': 'action',
+            # Otherwise we finish.
+            'end': END,
+        },
+    )
+
+    # We now add a normal edge from `tools` to `agent`.
+    # This means that after `tools` is called, `agent` node is called next.
+    workflow.add_edge('action', 'agent')
+
+    # Finally, we compile it!
+    # This compiles it into a LangChain Runnable,
+    # meaning you can use it as you would any other runnable
+    app = workflow.compile()
+    if interrupt_before_action:
+        app.interrupt = ['action:inbox']
+    return app
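
A minimal end-to-end sketch of the OpenAI-style executor; the tool, model name, and final-message handling are assumptions for illustration, not code from the package:

# Hypothetical driver for get_openai_functions_agent_executor.
import asyncio

from bisheng_langchain.gpts.agent_types import get_openai_functions_agent_executor
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

@tool
def get_word_length(word: str) -> int:
    """Return the number of characters in a word."""
    return len(word)

async def main():
    app = get_openai_functions_agent_executor(
        tools=[get_word_length],
        llm=ChatOpenAI(model='gpt-4-0125-preview'),
        system_message='You are a helpful assistant.',
        interrupt_before_action=False,
    )
    # A MessageGraph app consumes and returns a list of messages.
    messages = await app.ainvoke([HumanMessage(content='How long is "bisheng"?')])
    print(messages[-1].content)

asyncio.run(main())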