langchain 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/agents/agent.py +38 -21
- langchain/agents/agent_iterator.py +2 -2
- langchain/agents/agent_types.py +1 -1
- langchain/agents/openai_assistant/base.py +11 -15
- langchain/agents/openai_tools/base.py +8 -3
- langchain/chains/combine_documents/base.py +4 -4
- langchain/chains/combine_documents/stuff.py +1 -1
- langchain/chains/constitutional_ai/base.py +144 -1
- langchain/chains/conversation/base.py +1 -1
- langchain/chains/flare/base.py +46 -70
- langchain/chains/hyde/base.py +18 -8
- langchain/chains/llm_math/base.py +118 -1
- langchain/chains/moderation.py +8 -7
- langchain/chains/natbot/base.py +24 -10
- langchain/chains/query_constructor/parser.py +23 -0
- langchain/retrievers/document_compressors/chain_extract.py +19 -10
- langchain/retrievers/document_compressors/chain_filter.py +27 -10
- langchain/retrievers/re_phraser.py +7 -7
- langchain/retrievers/self_query/base.py +11 -2
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/METADATA +2 -2
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/RECORD +24 -24
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/LICENSE +0 -0
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/WHEEL +0 -0
- {langchain-0.2.14.dist-info → langchain-0.2.16.dist-info}/entry_points.txt +0 -0
langchain/agents/agent.py
CHANGED
|
@@ -19,6 +19,7 @@ from typing import (
|
|
|
19
19
|
Sequence,
|
|
20
20
|
Tuple,
|
|
21
21
|
Union,
|
|
22
|
+
cast,
|
|
22
23
|
)
|
|
23
24
|
|
|
24
25
|
import yaml
|
|
@@ -629,7 +630,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
|
|
629
630
|
|
|
630
631
|
@deprecated(
|
|
631
632
|
"0.1.0",
|
|
632
|
-
|
|
633
|
+
message=(
|
|
633
634
|
"Use new agent constructor methods like create_react_agent, create_json_agent, "
|
|
634
635
|
"create_structured_chat_agent, etc."
|
|
635
636
|
),
|
|
@@ -720,7 +721,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
|
|
720
721
|
|
|
721
722
|
@deprecated(
|
|
722
723
|
"0.1.0",
|
|
723
|
-
|
|
724
|
+
message=(
|
|
724
725
|
"Use new agent constructor methods like create_react_agent, create_json_agent, "
|
|
725
726
|
"create_structured_chat_agent, etc."
|
|
726
727
|
),
|
|
@@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool):
|
|
|
1042
1043
|
|
|
1043
1044
|
|
|
1044
1045
|
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
|
|
1046
|
+
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
|
|
1045
1047
|
|
|
1046
1048
|
|
|
1047
1049
|
class AgentExecutor(Chain):
|
|
1048
1050
|
"""Agent that is using tools."""
|
|
1049
1051
|
|
|
1050
|
-
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
|
|
1052
|
+
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable]
|
|
1051
1053
|
"""The agent to run for creating a plan and determining actions
|
|
1052
1054
|
to take at each step of the execution loop."""
|
|
1053
1055
|
tools: Sequence[BaseTool]
|
|
@@ -1095,7 +1097,7 @@ class AgentExecutor(Chain):
|
|
|
1095
1097
|
@classmethod
|
|
1096
1098
|
def from_agent_and_tools(
|
|
1097
1099
|
cls,
|
|
1098
|
-
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
|
1100
|
+
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable],
|
|
1099
1101
|
tools: Sequence[BaseTool],
|
|
1100
1102
|
callbacks: Callbacks = None,
|
|
1101
1103
|
**kwargs: Any,
|
|
@@ -1172,6 +1174,21 @@ class AgentExecutor(Chain):
|
|
|
1172
1174
|
)
|
|
1173
1175
|
return values
|
|
1174
1176
|
|
|
1177
|
+
@property
|
|
1178
|
+
def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
|
1179
|
+
"""Type cast self.agent.
|
|
1180
|
+
|
|
1181
|
+
The .agent attribute type includes Runnable, but is converted to one of
|
|
1182
|
+
RunnableAgentType in the validate_runnable_agent root_validator.
|
|
1183
|
+
|
|
1184
|
+
To support instantiating with a Runnable, here we explicitly cast the type
|
|
1185
|
+
to reflect the changes made in the root_validator.
|
|
1186
|
+
"""
|
|
1187
|
+
if isinstance(self.agent, Runnable):
|
|
1188
|
+
return cast(RunnableAgentType, self.agent)
|
|
1189
|
+
else:
|
|
1190
|
+
return self.agent
|
|
1191
|
+
|
|
1175
1192
|
def save(self, file_path: Union[Path, str]) -> None:
|
|
1176
1193
|
"""Raise error - saving not supported for Agent Executors.
|
|
1177
1194
|
|
|
@@ -1193,7 +1210,7 @@ class AgentExecutor(Chain):
|
|
|
1193
1210
|
Args:
|
|
1194
1211
|
file_path: Path to save to.
|
|
1195
1212
|
"""
|
|
1196
|
-
return self.
|
|
1213
|
+
return self._action_agent.save(file_path)
|
|
1197
1214
|
|
|
1198
1215
|
def iter(
|
|
1199
1216
|
self,
|
|
@@ -1228,7 +1245,7 @@ class AgentExecutor(Chain):
|
|
|
1228
1245
|
|
|
1229
1246
|
:meta private:
|
|
1230
1247
|
"""
|
|
1231
|
-
return self.
|
|
1248
|
+
return self._action_agent.input_keys
|
|
1232
1249
|
|
|
1233
1250
|
@property
|
|
1234
1251
|
def output_keys(self) -> List[str]:
|
|
@@ -1237,9 +1254,9 @@ class AgentExecutor(Chain):
|
|
|
1237
1254
|
:meta private:
|
|
1238
1255
|
"""
|
|
1239
1256
|
if self.return_intermediate_steps:
|
|
1240
|
-
return self.
|
|
1257
|
+
return self._action_agent.return_values + ["intermediate_steps"]
|
|
1241
1258
|
else:
|
|
1242
|
-
return self.
|
|
1259
|
+
return self._action_agent.return_values
|
|
1243
1260
|
|
|
1244
1261
|
def lookup_tool(self, name: str) -> BaseTool:
|
|
1245
1262
|
"""Lookup tool by name.
|
|
@@ -1339,7 +1356,7 @@ class AgentExecutor(Chain):
|
|
|
1339
1356
|
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
|
1340
1357
|
|
|
1341
1358
|
# Call the LLM to see what to do.
|
|
1342
|
-
output = self.
|
|
1359
|
+
output = self._action_agent.plan(
|
|
1343
1360
|
intermediate_steps,
|
|
1344
1361
|
callbacks=run_manager.get_child() if run_manager else None,
|
|
1345
1362
|
**inputs,
|
|
@@ -1372,7 +1389,7 @@ class AgentExecutor(Chain):
|
|
|
1372
1389
|
output = AgentAction("_Exception", observation, text)
|
|
1373
1390
|
if run_manager:
|
|
1374
1391
|
run_manager.on_agent_action(output, color="green")
|
|
1375
|
-
tool_run_kwargs = self.
|
|
1392
|
+
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
|
1376
1393
|
observation = ExceptionTool().run(
|
|
1377
1394
|
output.tool_input,
|
|
1378
1395
|
verbose=self.verbose,
|
|
@@ -1414,7 +1431,7 @@ class AgentExecutor(Chain):
|
|
|
1414
1431
|
tool = name_to_tool_map[agent_action.tool]
|
|
1415
1432
|
return_direct = tool.return_direct
|
|
1416
1433
|
color = color_mapping[agent_action.tool]
|
|
1417
|
-
tool_run_kwargs = self.
|
|
1434
|
+
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
|
1418
1435
|
if return_direct:
|
|
1419
1436
|
tool_run_kwargs["llm_prefix"] = ""
|
|
1420
1437
|
# We then call the tool on the tool input to get an observation
|
|
@@ -1426,7 +1443,7 @@ class AgentExecutor(Chain):
|
|
|
1426
1443
|
**tool_run_kwargs,
|
|
1427
1444
|
)
|
|
1428
1445
|
else:
|
|
1429
|
-
tool_run_kwargs = self.
|
|
1446
|
+
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
|
1430
1447
|
observation = InvalidTool().run(
|
|
1431
1448
|
{
|
|
1432
1449
|
"requested_tool_name": agent_action.tool,
|
|
@@ -1476,7 +1493,7 @@ class AgentExecutor(Chain):
|
|
|
1476
1493
|
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
|
1477
1494
|
|
|
1478
1495
|
# Call the LLM to see what to do.
|
|
1479
|
-
output = await self.
|
|
1496
|
+
output = await self._action_agent.aplan(
|
|
1480
1497
|
intermediate_steps,
|
|
1481
1498
|
callbacks=run_manager.get_child() if run_manager else None,
|
|
1482
1499
|
**inputs,
|
|
@@ -1507,7 +1524,7 @@ class AgentExecutor(Chain):
|
|
|
1507
1524
|
else:
|
|
1508
1525
|
raise ValueError("Got unexpected type of `handle_parsing_errors`")
|
|
1509
1526
|
output = AgentAction("_Exception", observation, text)
|
|
1510
|
-
tool_run_kwargs = self.
|
|
1527
|
+
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
|
1511
1528
|
observation = await ExceptionTool().arun(
|
|
1512
1529
|
output.tool_input,
|
|
1513
1530
|
verbose=self.verbose,
|
|
@@ -1561,7 +1578,7 @@ class AgentExecutor(Chain):
|
|
|
1561
1578
|
tool = name_to_tool_map[agent_action.tool]
|
|
1562
1579
|
return_direct = tool.return_direct
|
|
1563
1580
|
color = color_mapping[agent_action.tool]
|
|
1564
|
-
tool_run_kwargs = self.
|
|
1581
|
+
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
|
1565
1582
|
if return_direct:
|
|
1566
1583
|
tool_run_kwargs["llm_prefix"] = ""
|
|
1567
1584
|
# We then call the tool on the tool input to get an observation
|
|
@@ -1573,7 +1590,7 @@ class AgentExecutor(Chain):
|
|
|
1573
1590
|
**tool_run_kwargs,
|
|
1574
1591
|
)
|
|
1575
1592
|
else:
|
|
1576
|
-
tool_run_kwargs = self.
|
|
1593
|
+
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
|
1577
1594
|
observation = await InvalidTool().arun(
|
|
1578
1595
|
{
|
|
1579
1596
|
"requested_tool_name": agent_action.tool,
|
|
@@ -1628,7 +1645,7 @@ class AgentExecutor(Chain):
|
|
|
1628
1645
|
)
|
|
1629
1646
|
iterations += 1
|
|
1630
1647
|
time_elapsed = time.time() - start_time
|
|
1631
|
-
output = self.
|
|
1648
|
+
output = self._action_agent.return_stopped_response(
|
|
1632
1649
|
self.early_stopping_method, intermediate_steps, **inputs
|
|
1633
1650
|
)
|
|
1634
1651
|
return self._return(output, intermediate_steps, run_manager=run_manager)
|
|
@@ -1680,7 +1697,7 @@ class AgentExecutor(Chain):
|
|
|
1680
1697
|
|
|
1681
1698
|
iterations += 1
|
|
1682
1699
|
time_elapsed = time.time() - start_time
|
|
1683
|
-
output = self.
|
|
1700
|
+
output = self._action_agent.return_stopped_response(
|
|
1684
1701
|
self.early_stopping_method, intermediate_steps, **inputs
|
|
1685
1702
|
)
|
|
1686
1703
|
return await self._areturn(
|
|
@@ -1688,7 +1705,7 @@ class AgentExecutor(Chain):
|
|
|
1688
1705
|
)
|
|
1689
1706
|
except (TimeoutError, asyncio.TimeoutError):
|
|
1690
1707
|
# stop early when interrupted by the async timeout
|
|
1691
|
-
output = self.
|
|
1708
|
+
output = self._action_agent.return_stopped_response(
|
|
1692
1709
|
self.early_stopping_method, intermediate_steps, **inputs
|
|
1693
1710
|
)
|
|
1694
1711
|
return await self._areturn(
|
|
@@ -1702,8 +1719,8 @@ class AgentExecutor(Chain):
|
|
|
1702
1719
|
agent_action, observation = next_step_output
|
|
1703
1720
|
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
|
1704
1721
|
return_value_key = "output"
|
|
1705
|
-
if len(self.
|
|
1706
|
-
return_value_key = self.
|
|
1722
|
+
if len(self._action_agent.return_values) > 0:
|
|
1723
|
+
return_value_key = self._action_agent.return_values[0]
|
|
1707
1724
|
# Invalid tools won't be in the map, so we return False.
|
|
1708
1725
|
if agent_action.tool in name_to_tool_map:
|
|
1709
1726
|
if name_to_tool_map[agent_action.tool].return_direct:
|
|
@@ -371,7 +371,7 @@ class AgentExecutorIterator:
|
|
|
371
371
|
"""
|
|
372
372
|
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
|
373
373
|
# this manually constructs agent finish with output key
|
|
374
|
-
output = self.agent_executor.
|
|
374
|
+
output = self.agent_executor._action_agent.return_stopped_response(
|
|
375
375
|
self.agent_executor.early_stopping_method,
|
|
376
376
|
self.intermediate_steps,
|
|
377
377
|
**self.inputs,
|
|
@@ -384,7 +384,7 @@ class AgentExecutorIterator:
|
|
|
384
384
|
the stopped response.
|
|
385
385
|
"""
|
|
386
386
|
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
|
387
|
-
output = self.agent_executor.
|
|
387
|
+
output = self.agent_executor._action_agent.return_stopped_response(
|
|
388
388
|
self.agent_executor.early_stopping_method,
|
|
389
389
|
self.intermediate_steps,
|
|
390
390
|
**self.inputs,
|
langchain/agents/agent_types.py
CHANGED
|
@@ -272,7 +272,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
272
272
|
instructions=instructions,
|
|
273
273
|
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
|
|
274
274
|
model=model,
|
|
275
|
-
file_ids=kwargs.get("file_ids"),
|
|
276
275
|
)
|
|
277
276
|
return cls(assistant_id=assistant.id, client=client, **kwargs)
|
|
278
277
|
|
|
@@ -287,7 +286,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
287
286
|
thread_id: Existing thread to use.
|
|
288
287
|
run_id: Existing run to use. Should only be supplied when providing
|
|
289
288
|
the tool output for a required action after an initial invocation.
|
|
290
|
-
file_ids: File ids to include in new run. Used for retrieval.
|
|
291
289
|
message_metadata: Metadata to associate with new message.
|
|
292
290
|
thread_metadata: Metadata to associate with new thread. Only relevant
|
|
293
291
|
when new thread being created.
|
|
@@ -327,7 +325,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
327
325
|
{
|
|
328
326
|
"role": "user",
|
|
329
327
|
"content": input["content"],
|
|
330
|
-
"file_ids": input.get("file_ids", []),
|
|
331
328
|
"metadata": input.get("message_metadata"),
|
|
332
329
|
}
|
|
333
330
|
],
|
|
@@ -340,7 +337,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
340
337
|
input["thread_id"],
|
|
341
338
|
content=input["content"],
|
|
342
339
|
role="user",
|
|
343
|
-
file_ids=input.get("file_ids", []),
|
|
344
340
|
metadata=input.get("message_metadata"),
|
|
345
341
|
)
|
|
346
342
|
run = self._create_run(input)
|
|
@@ -394,7 +390,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
394
390
|
instructions=instructions,
|
|
395
391
|
tools=openai_tools, # type: ignore
|
|
396
392
|
model=model,
|
|
397
|
-
file_ids=kwargs.get("file_ids"),
|
|
398
393
|
)
|
|
399
394
|
return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
|
|
400
395
|
|
|
@@ -409,7 +404,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
409
404
|
thread_id: Existing thread to use.
|
|
410
405
|
run_id: Existing run to use. Should only be supplied when providing
|
|
411
406
|
the tool output for a required action after an initial invocation.
|
|
412
|
-
file_ids: File ids to include in new run. Used for retrieval.
|
|
413
407
|
message_metadata: Metadata to associate with a new message.
|
|
414
408
|
thread_metadata: Metadata to associate with new thread. Only relevant
|
|
415
409
|
when a new thread is created.
|
|
@@ -439,7 +433,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
439
433
|
try:
|
|
440
434
|
# Being run within AgentExecutor and there are tool outputs to submit.
|
|
441
435
|
if self.as_agent and input.get("intermediate_steps"):
|
|
442
|
-
tool_outputs = self.
|
|
436
|
+
tool_outputs = await self._aparse_intermediate_steps(
|
|
443
437
|
input["intermediate_steps"]
|
|
444
438
|
)
|
|
445
439
|
run = await self.async_client.beta.threads.runs.submit_tool_outputs(
|
|
@@ -452,7 +446,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
452
446
|
{
|
|
453
447
|
"role": "user",
|
|
454
448
|
"content": input["content"],
|
|
455
|
-
"file_ids": input.get("file_ids", []),
|
|
456
449
|
"metadata": input.get("message_metadata"),
|
|
457
450
|
}
|
|
458
451
|
],
|
|
@@ -465,7 +458,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
465
458
|
input["thread_id"],
|
|
466
459
|
content=input["content"],
|
|
467
460
|
role="user",
|
|
468
|
-
file_ids=input.get("file_ids", []),
|
|
469
461
|
metadata=input.get("message_metadata"),
|
|
470
462
|
)
|
|
471
463
|
run = await self._acreate_run(input)
|
|
@@ -493,9 +485,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
493
485
|
) -> dict:
|
|
494
486
|
last_action, last_output = intermediate_steps[-1]
|
|
495
487
|
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
|
|
496
|
-
required_tool_call_ids =
|
|
497
|
-
|
|
498
|
-
|
|
488
|
+
required_tool_call_ids = set()
|
|
489
|
+
if run.required_action:
|
|
490
|
+
required_tool_call_ids = {
|
|
491
|
+
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
|
|
492
|
+
}
|
|
499
493
|
tool_outputs = [
|
|
500
494
|
{"output": str(output), "tool_call_id": action.tool_call_id}
|
|
501
495
|
for action, output in intermediate_steps
|
|
@@ -621,9 +615,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|
|
621
615
|
) -> dict:
|
|
622
616
|
last_action, last_output = intermediate_steps[-1]
|
|
623
617
|
run = await self._wait_for_run(last_action.run_id, last_action.thread_id)
|
|
624
|
-
required_tool_call_ids =
|
|
625
|
-
|
|
626
|
-
|
|
618
|
+
required_tool_call_ids = set()
|
|
619
|
+
if run.required_action:
|
|
620
|
+
required_tool_call_ids = {
|
|
621
|
+
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
|
|
622
|
+
}
|
|
627
623
|
tool_outputs = [
|
|
628
624
|
{"output": str(output), "tool_call_id": action.tool_call_id}
|
|
629
625
|
for action, output in intermediate_steps
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Sequence
|
|
1
|
+
from typing import Optional, Sequence
|
|
2
2
|
|
|
3
3
|
from langchain_core.language_models import BaseLanguageModel
|
|
4
4
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
@@ -13,7 +13,10 @@ from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputP
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def create_openai_tools_agent(
|
|
16
|
-
llm: BaseLanguageModel,
|
|
16
|
+
llm: BaseLanguageModel,
|
|
17
|
+
tools: Sequence[BaseTool],
|
|
18
|
+
prompt: ChatPromptTemplate,
|
|
19
|
+
strict: Optional[bool] = None,
|
|
17
20
|
) -> Runnable:
|
|
18
21
|
"""Create an agent that uses OpenAI tools.
|
|
19
22
|
|
|
@@ -87,7 +90,9 @@ def create_openai_tools_agent(
|
|
|
87
90
|
if missing_vars:
|
|
88
91
|
raise ValueError(f"Prompt missing required variables: {missing_vars}")
|
|
89
92
|
|
|
90
|
-
llm_with_tools = llm.bind(
|
|
93
|
+
llm_with_tools = llm.bind(
|
|
94
|
+
tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools]
|
|
95
|
+
)
|
|
91
96
|
|
|
92
97
|
agent = (
|
|
93
98
|
RunnablePassthrough.assign(
|
|
@@ -22,11 +22,11 @@ DOCUMENTS_KEY = "context"
|
|
|
22
22
|
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def _validate_prompt(prompt: BasePromptTemplate) -> None:
|
|
26
|
-
if
|
|
25
|
+
def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
|
|
26
|
+
if document_variable_name not in prompt.input_variables:
|
|
27
27
|
raise ValueError(
|
|
28
|
-
f"Prompt must accept {
|
|
29
|
-
f"with input variables: {prompt.input_variables}"
|
|
28
|
+
f"Prompt must accept {document_variable_name} as an input variable. "
|
|
29
|
+
f"Received prompt with input variables: {prompt.input_variables}"
|
|
30
30
|
)
|
|
31
31
|
|
|
32
32
|
|
|
@@ -76,7 +76,7 @@ def create_stuff_documents_chain(
|
|
|
76
76
|
chain.invoke({"context": docs})
|
|
77
77
|
""" # noqa: E501
|
|
78
78
|
|
|
79
|
-
_validate_prompt(prompt)
|
|
79
|
+
_validate_prompt(prompt, document_variable_name)
|
|
80
80
|
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
|
|
81
81
|
_output_parser = output_parser or StrOutputParser()
|
|
82
82
|
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, List, Optional
|
|
4
4
|
|
|
5
|
+
from langchain_core._api import deprecated
|
|
5
6
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
|
6
7
|
from langchain_core.language_models import BaseLanguageModel
|
|
7
8
|
from langchain_core.prompts import BasePromptTemplate
|
|
@@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
|
|
|
13
14
|
from langchain.chains.llm import LLMChain
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
@deprecated(
|
|
18
|
+
since="0.2.13",
|
|
19
|
+
message=(
|
|
20
|
+
"This class is deprecated and will be removed in langchain 1.0. "
|
|
21
|
+
"See API reference for replacement: "
|
|
22
|
+
"https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501
|
|
23
|
+
),
|
|
24
|
+
removal="1.0",
|
|
25
|
+
)
|
|
16
26
|
class ConstitutionalChain(Chain):
|
|
17
27
|
"""Chain for applying constitutional principles.
|
|
18
28
|
|
|
29
|
+
Note: this class is deprecated. See below for a replacement implementation
|
|
30
|
+
using LangGraph. The benefits of this implementation are:
|
|
31
|
+
|
|
32
|
+
- Uses LLM tool calling features instead of parsing string responses;
|
|
33
|
+
- Support for both token-by-token and step-by-step streaming;
|
|
34
|
+
- Support for checkpointing and memory of chat history;
|
|
35
|
+
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
|
|
36
|
+
|
|
37
|
+
Install LangGraph with:
|
|
38
|
+
|
|
39
|
+
.. code-block:: bash
|
|
40
|
+
|
|
41
|
+
pip install -U langgraph
|
|
42
|
+
|
|
43
|
+
.. code-block:: python
|
|
44
|
+
|
|
45
|
+
from typing import List, Optional, Tuple
|
|
46
|
+
|
|
47
|
+
from langchain.chains.constitutional_ai.prompts import (
|
|
48
|
+
CRITIQUE_PROMPT,
|
|
49
|
+
REVISION_PROMPT,
|
|
50
|
+
)
|
|
51
|
+
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
|
52
|
+
from langchain_core.output_parsers import StrOutputParser
|
|
53
|
+
from langchain_core.prompts import ChatPromptTemplate
|
|
54
|
+
from langchain_openai import ChatOpenAI
|
|
55
|
+
from langgraph.graph import END, START, StateGraph
|
|
56
|
+
from typing_extensions import Annotated, TypedDict
|
|
57
|
+
|
|
58
|
+
llm = ChatOpenAI(model="gpt-4o-mini")
|
|
59
|
+
|
|
60
|
+
class Critique(TypedDict):
|
|
61
|
+
\"\"\"Generate a critique, if needed.\"\"\"
|
|
62
|
+
critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
|
|
63
|
+
critique: Annotated[str, ..., "If needed, the critique."]
|
|
64
|
+
|
|
65
|
+
critique_prompt = ChatPromptTemplate.from_template(
|
|
66
|
+
"Critique this response according to the critique request. "
|
|
67
|
+
"If no critique is needed, specify that.\\n\\n"
|
|
68
|
+
"Query: {query}\\n\\n"
|
|
69
|
+
"Response: {response}\\n\\n"
|
|
70
|
+
"Critique request: {critique_request}"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
revision_prompt = ChatPromptTemplate.from_template(
|
|
74
|
+
"Revise this response according to the critique and reivsion request.\\n\\n"
|
|
75
|
+
"Query: {query}\\n\\n"
|
|
76
|
+
"Response: {response}\\n\\n"
|
|
77
|
+
"Critique request: {critique_request}\\n\\n"
|
|
78
|
+
"Critique: {critique}\\n\\n"
|
|
79
|
+
"If the critique does not identify anything worth changing, ignore the "
|
|
80
|
+
"revision request and return 'No revisions needed'. If the critique "
|
|
81
|
+
"does identify something worth changing, revise the response based on "
|
|
82
|
+
"the revision request.\\n\\n"
|
|
83
|
+
"Revision Request: {revision_request}"
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
chain = llm | StrOutputParser()
|
|
87
|
+
critique_chain = critique_prompt | llm.with_structured_output(Critique)
|
|
88
|
+
revision_chain = revision_prompt | llm | StrOutputParser()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class State(TypedDict):
|
|
92
|
+
query: str
|
|
93
|
+
constitutional_principles: List[ConstitutionalPrinciple]
|
|
94
|
+
initial_response: str
|
|
95
|
+
critiques_and_revisions: List[Tuple[str, str]]
|
|
96
|
+
response: str
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
async def generate_response(state: State):
|
|
100
|
+
\"\"\"Generate initial response.\"\"\"
|
|
101
|
+
response = await chain.ainvoke(state["query"])
|
|
102
|
+
return {"response": response, "initial_response": response}
|
|
103
|
+
|
|
104
|
+
async def critique_and_revise(state: State):
|
|
105
|
+
\"\"\"Critique and revise response according to principles.\"\"\"
|
|
106
|
+
critiques_and_revisions = []
|
|
107
|
+
response = state["initial_response"]
|
|
108
|
+
for principle in state["constitutional_principles"]:
|
|
109
|
+
critique = await critique_chain.ainvoke(
|
|
110
|
+
{
|
|
111
|
+
"query": state["query"],
|
|
112
|
+
"response": response,
|
|
113
|
+
"critique_request": principle.critique_request,
|
|
114
|
+
}
|
|
115
|
+
)
|
|
116
|
+
if critique["critique_needed"]:
|
|
117
|
+
revision = await revision_chain.ainvoke(
|
|
118
|
+
{
|
|
119
|
+
"query": state["query"],
|
|
120
|
+
"response": response,
|
|
121
|
+
"critique_request": principle.critique_request,
|
|
122
|
+
"critique": critique["critique"],
|
|
123
|
+
"revision_request": principle.revision_request,
|
|
124
|
+
}
|
|
125
|
+
)
|
|
126
|
+
response = revision
|
|
127
|
+
critiques_and_revisions.append((critique["critique"], revision))
|
|
128
|
+
else:
|
|
129
|
+
critiques_and_revisions.append((critique["critique"], ""))
|
|
130
|
+
return {
|
|
131
|
+
"critiques_and_revisions": critiques_and_revisions,
|
|
132
|
+
"response": response,
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
graph = StateGraph(State)
|
|
136
|
+
graph.add_node("generate_response", generate_response)
|
|
137
|
+
graph.add_node("critique_and_revise", critique_and_revise)
|
|
138
|
+
|
|
139
|
+
graph.add_edge(START, "generate_response")
|
|
140
|
+
graph.add_edge("generate_response", "critique_and_revise")
|
|
141
|
+
graph.add_edge("critique_and_revise", END)
|
|
142
|
+
app = graph.compile()
|
|
143
|
+
|
|
144
|
+
.. code-block:: python
|
|
145
|
+
|
|
146
|
+
constitutional_principles=[
|
|
147
|
+
ConstitutionalPrinciple(
|
|
148
|
+
critique_request="Tell if this answer is good.",
|
|
149
|
+
revision_request="Give a better answer.",
|
|
150
|
+
)
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
query = "What is the meaning of life? Answer in 10 words or fewer."
|
|
154
|
+
|
|
155
|
+
async for step in app.astream(
|
|
156
|
+
{"query": query, "constitutional_principles": constitutional_principles},
|
|
157
|
+
stream_mode="values",
|
|
158
|
+
):
|
|
159
|
+
subset = ["initial_response", "critiques_and_revisions", "response"]
|
|
160
|
+
print({k: v for k, v in step.items() if k in subset})
|
|
161
|
+
|
|
19
162
|
Example:
|
|
20
163
|
.. code-block:: python
|
|
21
164
|
|
|
@@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
|
|
|
44
187
|
)
|
|
45
188
|
|
|
46
189
|
constitutional_chain.run(question="What is the meaning of life?")
|
|
47
|
-
"""
|
|
190
|
+
""" # noqa: E501
|
|
48
191
|
|
|
49
192
|
chain: LLMChain
|
|
50
193
|
constitutional_principles: List[ConstitutionalPrinciple]
|
|
@@ -16,7 +16,7 @@ from langchain.memory.buffer import ConversationBufferMemory
|
|
|
16
16
|
since="0.2.7",
|
|
17
17
|
alternative=(
|
|
18
18
|
"RunnableWithMessageHistory: "
|
|
19
|
-
"https://
|
|
19
|
+
"https://python.langchain.com/v0.2/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html" # noqa: E501
|
|
20
20
|
),
|
|
21
21
|
removal="1.0",
|
|
22
22
|
)
|