alita-sdk 0.3.181__py3-none-any.whl → 0.3.182__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/runtime/clients/client.py +2 -0
- alita_sdk/runtime/langchain/assistant.py +8 -57
- alita_sdk/runtime/langchain/langraph_agent.py +1 -1
- alita_sdk/runtime/llms/alita.py +11 -4
- alita_sdk/runtime/utils/streamlit.py +2 -7
- {alita_sdk-0.3.181.dist-info → alita_sdk-0.3.182.dist-info}/METADATA +2 -2
- {alita_sdk-0.3.181.dist-info → alita_sdk-0.3.182.dist-info}/RECORD +10 -11
- alita_sdk/runtime/langchain/agents/react_agent.py +0 -157
- {alita_sdk-0.3.181.dist-info → alita_sdk-0.3.182.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.181.dist-info → alita_sdk-0.3.182.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.181.dist-info → alita_sdk-0.3.182.dist-info}/top_level.txt +0 -0
@@ -218,6 +218,8 @@ class AlitaClient:
|
|
218
218
|
return LangChainAssistant(self, data, client,
|
219
219
|
chat_history, app_type,
|
220
220
|
tools=tools, memory=memory, store=store).runnable()
|
221
|
+
elif runtime == 'llama':
|
222
|
+
raise NotImplementedError("LLama runtime is not supported")
|
221
223
|
|
222
224
|
def datasource(self, datasource_id: int) -> AlitaDataSource:
|
223
225
|
url = f"{self.datasources}/{datasource_id}"
|
@@ -7,15 +7,10 @@ from langchain.agents import (
|
|
7
7
|
create_json_chat_agent)
|
8
8
|
from langgraph.store.base import BaseStore
|
9
9
|
|
10
|
-
# Note: Traditional LangChain agents (OpenAI, XML, JSON) do not support
|
11
|
-
# checkpointing natively. Only LangGraph agents support checkpointing.
|
12
|
-
# For checkpointing in traditional agents, you need to migrate to LangGraph.
|
13
|
-
|
14
10
|
from .agents.xml_chat import create_xml_chat_agent
|
15
|
-
from .agents.react_agent import get_langgraph_agent_with_auto_continue
|
16
11
|
from .langraph_agent import create_graph
|
17
12
|
from langchain_core.messages import (
|
18
|
-
BaseMessage, SystemMessage, HumanMessage
|
13
|
+
BaseMessage, SystemMessage, HumanMessage
|
19
14
|
)
|
20
15
|
from langchain_core.prompts import MessagesPlaceholder
|
21
16
|
from .constants import REACT_ADDON, REACT_VARS, XML_ADDON
|
@@ -34,8 +29,7 @@ class Assistant:
|
|
34
29
|
app_type: str = "openai",
|
35
30
|
tools: Optional[list] = [],
|
36
31
|
memory: Optional[Any] = None,
|
37
|
-
store: Optional[BaseStore] = None
|
38
|
-
use_checkpointing: bool = False):
|
32
|
+
store: Optional[BaseStore] = None):
|
39
33
|
|
40
34
|
self.client = copy(client)
|
41
35
|
self.client.max_tokens = data['llm_settings']['max_tokens']
|
@@ -48,7 +42,6 @@ class Assistant:
|
|
48
42
|
self.app_type = app_type
|
49
43
|
self.memory = memory
|
50
44
|
self.store = store
|
51
|
-
self.use_checkpointing = use_checkpointing
|
52
45
|
|
53
46
|
logger.debug("Data for agent creation: %s", data)
|
54
47
|
logger.info("App type: %s", app_type)
|
@@ -76,31 +69,13 @@ class Assistant:
|
|
76
69
|
self.tools = get_tools(data['tools'], alita_client=alita, llm=self.client, memory_store=self.store)
|
77
70
|
if app_type == "pipeline":
|
78
71
|
self.prompt = data['instructions']
|
79
|
-
elif app_type == "react":
|
80
|
-
self.tools += tools
|
81
|
-
messages = [SystemMessage(content=data['instructions'])]
|
82
|
-
# messages.append(HumanMessage("{{input}}"))
|
83
|
-
self.prompt = Jinja2TemplatedChatMessagesTemplate(messages=messages)
|
84
|
-
variables = {}
|
85
|
-
input_variables = []
|
86
|
-
for variable in data['variables']:
|
87
|
-
if variable['value'] != "":
|
88
|
-
variables[variable['name']] = variable['value']
|
89
|
-
else:
|
90
|
-
input_variables.append(variable['name'])
|
91
|
-
if input_variables:
|
92
|
-
self.prompt.input_variables = input_variables
|
93
|
-
if variables:
|
94
|
-
self.prompt.partial_variables = variables
|
95
|
-
try:
|
96
|
-
logger.info(f"Client was created with client setting: temperature - {self.client._get_model_default_parameters}")
|
97
|
-
except Exception as e:
|
98
|
-
logger.info(f"Client was created with client setting: temperature - {self.client.temperature} : {self.client.max_tokens}")
|
99
72
|
else:
|
100
73
|
self.tools += tools
|
101
74
|
messages = [SystemMessage(content=data['instructions'])]
|
102
75
|
messages.append(MessagesPlaceholder("chat_history"))
|
103
|
-
if app_type == "
|
76
|
+
if app_type == "react":
|
77
|
+
messages.append(HumanMessage(REACT_ADDON))
|
78
|
+
elif app_type == "xml":
|
104
79
|
messages.append(HumanMessage(XML_ADDON))
|
105
80
|
elif app_type in ['openai', 'dial']:
|
106
81
|
messages.append(HumanMessage("{{input}}"))
|
@@ -112,7 +87,7 @@ class Assistant:
|
|
112
87
|
variables[variable['name']] = variable['value']
|
113
88
|
else:
|
114
89
|
input_variables.append(variable['name'])
|
115
|
-
if app_type in ["xml"]:
|
90
|
+
if app_type in ["react", "xml"]:
|
116
91
|
input_variables = list(set(input_variables + REACT_VARS))
|
117
92
|
|
118
93
|
if chat_history and isinstance(chat_history, list):
|
@@ -143,17 +118,12 @@ class Assistant:
|
|
143
118
|
if self.app_type == 'pipeline':
|
144
119
|
return self.pipeline()
|
145
120
|
elif self.app_type == 'openai':
|
146
|
-
# Check if checkpointing is enabled - if so, use LangGraph for auto-continue capability
|
147
|
-
if self.use_checkpointing:
|
148
|
-
return self.getLangGraphAgentWithAutoContinue()
|
149
121
|
return self.getOpenAIToolsAgentExecutor()
|
150
122
|
elif self.app_type == 'xml':
|
151
|
-
# Check if checkpointing is enabled - if so, use LangGraph for auto-continue capability
|
152
|
-
if self.use_checkpointing:
|
153
|
-
return self.getLangGraphAgentWithAutoContinue()
|
154
123
|
return self.getXMLAgentExecutor()
|
155
124
|
else:
|
156
|
-
|
125
|
+
self.tools = [EchoTool()] + self.tools
|
126
|
+
return self.getAgentExecutor()
|
157
127
|
|
158
128
|
def _agent_executor(self, agent: Any):
|
159
129
|
return AgentExecutor.from_agent_and_tools(agent=agent, tools=self.tools,
|
@@ -199,22 +169,3 @@ class Assistant:
|
|
199
169
|
|
200
170
|
def predict(self, messages: list[BaseMessage]):
|
201
171
|
return self.client.invoke(messages)
|
202
|
-
|
203
|
-
def getLangGraphAgentWithAutoContinue(self):
|
204
|
-
"""
|
205
|
-
Create a LangGraph agent with auto-continue capability for when responses get truncated.
|
206
|
-
This provides better handling of length-limited responses compared to traditional AgentExecutor.
|
207
|
-
Uses simple in-memory checkpointing for auto-continue functionality.
|
208
|
-
|
209
|
-
Note: Requires LangGraph 0.5.x or higher that supports post_model_hook.
|
210
|
-
"""
|
211
|
-
|
212
|
-
# Use the function from react_agent.py to create the agent
|
213
|
-
return get_langgraph_agent_with_auto_continue(
|
214
|
-
prompt=self.prompt,
|
215
|
-
model=self.client,
|
216
|
-
tools=self.tools,
|
217
|
-
memory=self.memory,
|
218
|
-
debug=True
|
219
|
-
)
|
220
|
-
|
@@ -12,7 +12,7 @@ from langchain_core.runnables import RunnableConfig
|
|
12
12
|
from langchain_core.tools import BaseTool
|
13
13
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
14
14
|
from langgraph.graph import StateGraph
|
15
|
-
from langgraph.graph import END, START
|
15
|
+
from langgraph.graph.graph import END, START
|
16
16
|
from langgraph.graph.state import CompiledStateGraph
|
17
17
|
from langgraph.managed.base import is_managed_value
|
18
18
|
from langgraph.prebuilt import InjectedStore
|
alita_sdk/runtime/llms/alita.py
CHANGED
@@ -12,24 +12,31 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
|
16
|
+
#
|
17
|
+
# This is adoption of https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
|
18
|
+
#
|
19
|
+
|
15
20
|
import logging
|
16
21
|
import requests
|
17
22
|
from time import sleep
|
18
23
|
from traceback import format_exc
|
19
24
|
|
20
|
-
from typing import Any, List, Optional, AsyncIterator, Dict, Iterator
|
25
|
+
from typing import Any, List, Optional, AsyncIterator, Dict, Iterator, Mapping, Type
|
21
26
|
from tiktoken import get_encoding, encoding_for_model
|
22
27
|
from langchain_core.callbacks import (
|
23
28
|
AsyncCallbackManagerForLLMRun,
|
24
29
|
CallbackManagerForLLMRun,
|
25
30
|
)
|
26
|
-
from langchain_core.language_models import BaseChatModel
|
27
|
-
from langchain_core.messages import AIMessageChunk, BaseMessage
|
31
|
+
from langchain_core.language_models import BaseChatModel, SimpleChatModel
|
32
|
+
from langchain_core.messages import (AIMessageChunk, BaseMessage, HumanMessage, HumanMessageChunk, ChatMessageChunk,
|
33
|
+
FunctionMessageChunk, SystemMessageChunk, ToolMessageChunk, BaseMessageChunk,
|
34
|
+
AIMessage)
|
28
35
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
29
36
|
from langchain_core.runnables import run_in_executor
|
30
37
|
from langchain_community.chat_models.openai import generate_from_stream, _convert_delta_to_message_chunk
|
31
38
|
from ..clients import AlitaClient
|
32
|
-
from pydantic import Field, model_validator
|
39
|
+
from pydantic import Field, model_validator, field_validator, ValidationInfo
|
33
40
|
|
34
41
|
logger = logging.getLogger(__name__)
|
35
42
|
|
@@ -1038,14 +1038,9 @@ def run_streamlit(st, ai_icon=None, user_icon=None):
|
|
1038
1038
|
{"input": prompt, "chat_history": st.session_state.messages[:-1]},
|
1039
1039
|
{ 'callbacks': [st_cb], 'configurable': {"thread_id": st.session_state.thread_id}}
|
1040
1040
|
)
|
1041
|
-
|
1042
|
-
output =response["output"]
|
1043
|
-
except KeyError:
|
1044
|
-
logger.info(response)
|
1045
|
-
output = response['messages'][0].content
|
1046
|
-
st.write(output)
|
1041
|
+
st.write(response["output"])
|
1047
1042
|
st.session_state.thread_id = response.get("thread_id", None)
|
1048
|
-
st.session_state.messages.append({"role": "assistant", "content": output})
|
1043
|
+
st.session_state.messages.append({"role": "assistant", "content": response["output"]})
|
1049
1044
|
|
1050
1045
|
elif st.session_state.llm and st.session_state.show_toolkit_testing and st.session_state.configured_toolkit:
|
1051
1046
|
# Toolkit Testing Main View
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: alita_sdk
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.182
|
4
4
|
Summary: SDK for building langchain agents using resources from Alita
|
5
5
|
Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedjik@gmail.com>, Artem Dubrovskiy <ad13box@gmail.com>
|
6
6
|
License-Expression: Apache-2.0
|
@@ -32,7 +32,7 @@ Requires-Dist: langchain-openai~=0.3.0; extra == "runtime"
|
|
32
32
|
Requires-Dist: langgraph-checkpoint-sqlite~=2.0.0; extra == "runtime"
|
33
33
|
Requires-Dist: langgraph-checkpoint-postgres~=2.0.1; extra == "runtime"
|
34
34
|
Requires-Dist: langsmith>=0.3.45; extra == "runtime"
|
35
|
-
Requires-Dist: langgraph
|
35
|
+
Requires-Dist: langgraph<0.5,>=0.4.8; extra == "runtime"
|
36
36
|
Requires-Dist: langchain_chroma~=0.2.2; extra == "runtime"
|
37
37
|
Requires-Dist: langchain-unstructured~=0.1.6; extra == "runtime"
|
38
38
|
Requires-Dist: langchain-postgres~=0.0.13; extra == "runtime"
|
@@ -13,21 +13,20 @@ alita_sdk/community/analysis/jira_analyse/api_wrapper.py,sha256=Ui1GBWizIFGFOi98
|
|
13
13
|
alita_sdk/runtime/__init__.py,sha256=4W0UF-nl3QF2bvET5lnah4o24CoTwSoKXhuN0YnwvEE,828
|
14
14
|
alita_sdk/runtime/clients/__init__.py,sha256=BdehU5GBztN1Qi1Wul0cqlU46FxUfMnI6Vq2Zd_oq1M,296
|
15
15
|
alita_sdk/runtime/clients/artifact.py,sha256=4N2t5x3GibyXLq3Fvrv2o_VA7Z000yNfc-UN4eGsHZg,2679
|
16
|
-
alita_sdk/runtime/clients/client.py,sha256=
|
16
|
+
alita_sdk/runtime/clients/client.py,sha256=6ezOJ92CSw6b2PVs4uFMQKQdp40uT1awoFEqWAfBH_A,20029
|
17
17
|
alita_sdk/runtime/clients/datasource.py,sha256=HAZovoQN9jBg0_-lIlGBQzb4FJdczPhkHehAiVG3Wx0,1020
|
18
18
|
alita_sdk/runtime/clients/prompt.py,sha256=li1RG9eBwgNK_Qf0qUaZ8QNTmsncFrAL2pv3kbxZRZg,1447
|
19
19
|
alita_sdk/runtime/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
-
alita_sdk/runtime/langchain/assistant.py,sha256=
|
20
|
+
alita_sdk/runtime/langchain/assistant.py,sha256=QJEMiEOrFMJ4GpnK24U2pKFblrvdQpKFdfhZsI2wAUI,7507
|
21
21
|
alita_sdk/runtime/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
|
22
22
|
alita_sdk/runtime/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
|
23
23
|
alita_sdk/runtime/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
|
24
|
-
alita_sdk/runtime/langchain/langraph_agent.py,sha256=
|
24
|
+
alita_sdk/runtime/langchain/langraph_agent.py,sha256=ac96v-Nr7t1x1NzaOpvV1VuML3LfoNNIDTDl8pb-bSY,40505
|
25
25
|
alita_sdk/runtime/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
|
26
26
|
alita_sdk/runtime/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
|
27
27
|
alita_sdk/runtime/langchain/store_manager.py,sha256=w5-0GbPGJAw14g0CCD9BKFMznzk1I-iJ5OGj_HZJZgA,2211
|
28
28
|
alita_sdk/runtime/langchain/utils.py,sha256=Npferkn10dvdksnKzLJLBI5bNGQyVWTBwqp3vQtUqmY,6631
|
29
29
|
alita_sdk/runtime/langchain/agents/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
30
|
-
alita_sdk/runtime/langchain/agents/react_agent.py,sha256=y461yEn2J7eP5GYbiPL3qaU3e9R8OmvyJNuFFvbZRCY,5971
|
31
30
|
alita_sdk/runtime/langchain/agents/xml_chat.py,sha256=Mx7PK5T97_GrFCwHHZ3JZP42S7MwtUzV0W-_8j6Amt8,6212
|
32
31
|
alita_sdk/runtime/langchain/document_loaders/AlitaBDDScenariosLoader.py,sha256=4kFU1ijrM1Jw7cywQv8mUiBHlE6w-uqfzSZP4hUV5P4,3771
|
33
32
|
alita_sdk/runtime/langchain/document_loaders/AlitaCSVLoader.py,sha256=TBJuIFqweLDtd0JxgfPqrcY5eED-M617CT_EInp6Lmg,1949
|
@@ -64,7 +63,7 @@ alita_sdk/runtime/langchain/tools/bdd_parser/bdd_parser.py,sha256=DiEEOqDef2Xo3x
|
|
64
63
|
alita_sdk/runtime/langchain/tools/bdd_parser/feature_types.py,sha256=l3AdjSQnNv1CE1NuHi7wts6h6AsCiK-iPu0PnPf3jf0,399
|
65
64
|
alita_sdk/runtime/langchain/tools/bdd_parser/parser.py,sha256=1H1Nd_OH5Wx8A5YV1zUghBxo613yPptZ7fqNo8Eg48M,17289
|
66
65
|
alita_sdk/runtime/llms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
67
|
-
alita_sdk/runtime/llms/alita.py,sha256
|
66
|
+
alita_sdk/runtime/llms/alita.py,sha256=oAALCyrTbQ8gygJO7xzZLjOowpUnTbh8R363argRUVs,10119
|
68
67
|
alita_sdk/runtime/llms/preloaded.py,sha256=3AaUbZK3d8fvxAQMjR3ftOoYa0SnkCOL1EvdvDCXIHE,11321
|
69
68
|
alita_sdk/runtime/toolkits/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
70
69
|
alita_sdk/runtime/toolkits/application.py,sha256=akqUuaIL9u7-SsUmS-XgN4qxDEnXFhsK9do4n8inpSo,2432
|
@@ -97,7 +96,7 @@ alita_sdk/runtime/utils/constants.py,sha256=Xntx1b_uxUzT4clwqHA_U6K8y5bBqf_4lSQw
|
|
97
96
|
alita_sdk/runtime/utils/evaluate.py,sha256=iM1P8gzBLHTuSCe85_Ng_h30m52hFuGuhNXJ7kB1tgI,1872
|
98
97
|
alita_sdk/runtime/utils/logging.py,sha256=svPyiW8ztDfhqHFITv5FBCj8UhLxz6hWcqGIY6wpJJE,3331
|
99
98
|
alita_sdk/runtime/utils/save_dataframe.py,sha256=i-E1wp-t4wb17Zq3nA3xYwgSILjoXNizaQAA9opWvxY,1576
|
100
|
-
alita_sdk/runtime/utils/streamlit.py,sha256=
|
99
|
+
alita_sdk/runtime/utils/streamlit.py,sha256=z4J_bdxkA0zMROkvTB4u379YBRFCkKh-h7PD8RlnZWQ,85644
|
101
100
|
alita_sdk/runtime/utils/utils.py,sha256=dM8whOJAuFJFe19qJ69-FLzrUp6d2G-G6L7d4ss2XqM,346
|
102
101
|
alita_sdk/tools/__init__.py,sha256=l3KxV-Qtu-04QQ9YYcovbLtEkZ30fGWDZ7o9nuRF16o,9967
|
103
102
|
alita_sdk/tools/elitea_base.py,sha256=NQaIxPX6DVIerHCb18jwUR6maZxxk73NZaTsFHkBQWE,21119
|
@@ -296,8 +295,8 @@ alita_sdk/tools/zephyr_scale/api_wrapper.py,sha256=UHVQUVqcBc3SZvDfO78HSuBzwAsRw
|
|
296
295
|
alita_sdk/tools/zephyr_squad/__init__.py,sha256=rq4jOb3lRW2GXvAguk4H1KinO5f-zpygzhBJf-E1Ucw,2773
|
297
296
|
alita_sdk/tools/zephyr_squad/api_wrapper.py,sha256=iOMxyE7vOc_LwFB_nBMiSFXkNtvbptA4i-BrTlo7M0A,5854
|
298
297
|
alita_sdk/tools/zephyr_squad/zephyr_squad_cloud_client.py,sha256=IYUJoMFOMA70knLhLtAnuGoy3OK80RuqeQZ710oyIxE,3631
|
299
|
-
alita_sdk-0.3.
|
300
|
-
alita_sdk-0.3.
|
301
|
-
alita_sdk-0.3.
|
302
|
-
alita_sdk-0.3.
|
303
|
-
alita_sdk-0.3.
|
298
|
+
alita_sdk-0.3.182.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
299
|
+
alita_sdk-0.3.182.dist-info/METADATA,sha256=WqwMApLsjTHN2hiq_uUi8e6SOE7PdJZZXf50UULLvPk,18753
|
300
|
+
alita_sdk-0.3.182.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
301
|
+
alita_sdk-0.3.182.dist-info/top_level.txt,sha256=0vJYy5p_jK6AwVb1aqXr7Kgqgk3WDtQ6t5C-XI9zkmg,10
|
302
|
+
alita_sdk-0.3.182.dist-info/RECORD,,
|
@@ -1,157 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from typing import Any, Optional, Callable
|
3
|
-
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
|
4
|
-
from langchain_core.tools import BaseTool
|
5
|
-
from langgraph.prebuilt import create_react_agent
|
6
|
-
from langgraph.checkpoint.memory import MemorySaver
|
7
|
-
|
8
|
-
logger = logging.getLogger(__name__)
|
9
|
-
|
10
|
-
def create_react_agent_with_auto_continue(
|
11
|
-
prompt: Any,
|
12
|
-
model: Any,
|
13
|
-
tools: list[BaseTool],
|
14
|
-
memory: Optional[Any] = None,
|
15
|
-
debug: bool = False
|
16
|
-
) -> Any:
|
17
|
-
"""
|
18
|
-
Create a LangGraph React agent with auto-continue capability for when responses get truncated.
|
19
|
-
This provides better handling of length-limited responses compared to traditional AgentExecutor.
|
20
|
-
Uses simple in-memory checkpointing for auto-continue functionality.
|
21
|
-
|
22
|
-
Args:
|
23
|
-
prompt: The prompt template to use for the agent
|
24
|
-
model: The language model to use
|
25
|
-
tools: List of tools available to the agent
|
26
|
-
memory: Optional memory store for checkpointing (will create MemorySaver if None)
|
27
|
-
debug: Whether to enable debug mode
|
28
|
-
|
29
|
-
Returns:
|
30
|
-
A configured React agent with auto-continue capability
|
31
|
-
|
32
|
-
Note: Requires LangGraph 0.5.x or higher that supports post_model_hook.
|
33
|
-
"""
|
34
|
-
# Use simple in-memory checkpointer for auto-continue functionality if not provided
|
35
|
-
if memory is None:
|
36
|
-
memory = MemorySaver()
|
37
|
-
|
38
|
-
# Set up parameters for the agent
|
39
|
-
kwargs = {
|
40
|
-
"prompt": prompt,
|
41
|
-
"model": model,
|
42
|
-
"tools": tools,
|
43
|
-
"checkpointer": memory,
|
44
|
-
"post_model_hook": _create_auto_continue_hook() # Auto-continue hook
|
45
|
-
}
|
46
|
-
|
47
|
-
# Create the base React agent with langgraph's prebuilt function
|
48
|
-
base_agent = create_react_agent(**kwargs)
|
49
|
-
|
50
|
-
return base_agent
|
51
|
-
|
52
|
-
def _create_auto_continue_hook() -> Callable:
|
53
|
-
"""
|
54
|
-
Create a post-model hook for LangGraph 0.5.x that detects truncated responses
|
55
|
-
and adds continuation prompts.
|
56
|
-
This checks if the last AI message was truncated and automatically continues if needed.
|
57
|
-
"""
|
58
|
-
MAX_CONTINUATIONS = 3 # Maximum number of auto-continuations allowed
|
59
|
-
|
60
|
-
def post_model_hook(state):
|
61
|
-
messages = state.get("messages", [])
|
62
|
-
|
63
|
-
# Count how many auto-continue messages we've already sent
|
64
|
-
continuation_count = sum(
|
65
|
-
1 for msg in messages
|
66
|
-
if isinstance(msg, HumanMessage) and
|
67
|
-
"continue your previous response" in msg.content.lower()
|
68
|
-
)
|
69
|
-
|
70
|
-
# Don't continue if we've reached the limit
|
71
|
-
if continuation_count >= MAX_CONTINUATIONS:
|
72
|
-
return state
|
73
|
-
|
74
|
-
# Check if the last message is from AI and was truncated
|
75
|
-
if messages and isinstance(messages[-1], AIMessage):
|
76
|
-
last_ai_message = messages[-1]
|
77
|
-
|
78
|
-
# Check for truncation indicators
|
79
|
-
is_truncated = (
|
80
|
-
hasattr(last_ai_message, 'response_metadata') and
|
81
|
-
last_ai_message.response_metadata.get('finish_reason') == 'length'
|
82
|
-
) or (
|
83
|
-
# Fallback: check if message seems to end abruptly
|
84
|
-
last_ai_message.content and
|
85
|
-
not last_ai_message.content.rstrip().endswith(('.', '!', '?', ':', ';'))
|
86
|
-
)
|
87
|
-
|
88
|
-
# Add continuation request if truncated
|
89
|
-
if is_truncated:
|
90
|
-
logger.info("Detected truncated response, adding continuation request")
|
91
|
-
new_messages = messages.copy()
|
92
|
-
new_messages.append(HumanMessage(content="Continue your previous response from where you left off"))
|
93
|
-
return {"messages": new_messages}
|
94
|
-
|
95
|
-
return state
|
96
|
-
|
97
|
-
return post_model_hook
|
98
|
-
|
99
|
-
def get_langgraph_agent_with_auto_continue(
|
100
|
-
prompt: Any,
|
101
|
-
model: Any,
|
102
|
-
tools: list[BaseTool],
|
103
|
-
memory: Optional[Any] = None,
|
104
|
-
debug: bool = False
|
105
|
-
) -> Any:
|
106
|
-
"""
|
107
|
-
Create a LangGraph agent with auto-continue capability for when responses get truncated.
|
108
|
-
This provides better handling of length-limited responses compared to traditional AgentExecutor.
|
109
|
-
Uses simple in-memory checkpointing for auto-continue functionality.
|
110
|
-
|
111
|
-
Args:
|
112
|
-
prompt: The prompt template to use for the agent
|
113
|
-
model: The language model to use
|
114
|
-
tools: List of tools available to the agent
|
115
|
-
memory: Optional memory store for checkpointing (will create MemorySaver if None)
|
116
|
-
debug: Whether to enable debug mode
|
117
|
-
|
118
|
-
Returns:
|
119
|
-
A configured LangGraphAgentRunnable with auto-continue capability
|
120
|
-
|
121
|
-
Note: Requires LangGraph 0.5.x or higher that supports post_model_hook.
|
122
|
-
"""
|
123
|
-
from ...langchain.langraph_agent import LangGraphAgentRunnable
|
124
|
-
|
125
|
-
# Use simple in-memory checkpointer for auto-continue functionality if not provided
|
126
|
-
if memory is None:
|
127
|
-
memory = MemorySaver()
|
128
|
-
|
129
|
-
# Create the base React agent with auto-continue capability
|
130
|
-
base_agent = create_react_agent_with_auto_continue(
|
131
|
-
prompt=prompt,
|
132
|
-
model=model,
|
133
|
-
tools=tools,
|
134
|
-
memory=memory,
|
135
|
-
debug=debug
|
136
|
-
)
|
137
|
-
|
138
|
-
# Wrap the base agent in our custom LangGraphAgentRunnable to handle input properly
|
139
|
-
# This ensures that our invoke() input handling logic is applied
|
140
|
-
agent = LangGraphAgentRunnable(
|
141
|
-
builder=base_agent.builder,
|
142
|
-
config_type=base_agent.builder.config_schema,
|
143
|
-
nodes=base_agent.nodes,
|
144
|
-
channels=base_agent.channels,
|
145
|
-
input_channels=base_agent.input_channels,
|
146
|
-
stream_mode=base_agent.stream_mode,
|
147
|
-
output_channels=base_agent.output_channels,
|
148
|
-
stream_channels=base_agent.stream_channels,
|
149
|
-
checkpointer=memory,
|
150
|
-
interrupt_before_nodes=base_agent.interrupt_before_nodes,
|
151
|
-
interrupt_after_nodes=base_agent.interrupt_after_nodes,
|
152
|
-
debug=debug,
|
153
|
-
store=base_agent.store,
|
154
|
-
schema_to_mapper=base_agent.schema_to_mapper
|
155
|
-
)
|
156
|
-
|
157
|
-
return agent
|
File without changes
|
File without changes
|
File without changes
|