langchain 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
langchain/agents/agent.py CHANGED
```diff
@@ -19,6 +19,7 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 
 import yaml
@@ -629,7 +630,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
 
 @deprecated(
     "0.1.0",
-    alternative=(
+    message=(
         "Use new agent constructor methods like create_react_agent, create_json_agent, "
         "create_structured_chat_agent, etc."
     ),
@@ -720,7 +721,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
 
 @deprecated(
     "0.1.0",
-    alternative=(
+    message=(
         "Use new agent constructor methods like create_react_agent, create_json_agent, "
         "create_structured_chat_agent, etc."
     ),
@@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool):
 
 
 NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
+RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
 
 
 class AgentExecutor(Chain):
     """Agent that is using tools."""
 
-    agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
+    agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable]
    """The agent to run for creating a plan and determining actions
    to take at each step of the execution loop."""
    tools: Sequence[BaseTool]
@@ -1095,7 +1097,7 @@ class AgentExecutor(Chain):
     @classmethod
     def from_agent_and_tools(
         cls,
-        agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
+        agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable],
         tools: Sequence[BaseTool],
         callbacks: Callbacks = None,
         **kwargs: Any,
```
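The two hunks above widen `AgentExecutor.agent` so the executor can be constructed directly from any `Runnable`. Per the docstring added in the next hunk, the `validate_runnable_agent` root_validator already converted such values into one of the `RunnableAgentType` wrappers at runtime; the annotation now matches that behavior. A minimal construction sketch (a stub `RunnableLambda` stands in for a real LLM-backed agent; no API calls are made):

```python
from langchain_core.agents import AgentFinish
from langchain_core.runnables import RunnableLambda

from langchain.agents import AgentExecutor

# Stand-in agent that finishes immediately; a real agent runnable would come
# from a constructor such as create_react_agent or create_openai_tools_agent.
stub_agent = RunnableLambda(
    lambda inputs: AgentFinish(return_values={"output": "done"}, log="")
)

# The widened annotation now admits this directly; the root_validator wraps
# the Runnable in a RunnableAgent/RunnableMultiActionAgent before execution.
executor = AgentExecutor(agent=stub_agent, tools=[])
```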
```diff
@@ -1172,6 +1174,21 @@ class AgentExecutor(Chain):
         )
         return values
 
+    @property
+    def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
+        """Type cast self.agent.
+
+        The .agent attribute type includes Runnable, but is converted to one of
+        RunnableAgentType in the validate_runnable_agent root_validator.
+
+        To support instantiating with a Runnable, here we explicitly cast the type
+        to reflect the changes made in the root_validator.
+        """
+        if isinstance(self.agent, Runnable):
+            return cast(RunnableAgentType, self.agent)
+        else:
+            return self.agent
+
     def save(self, file_path: Union[Path, str]) -> None:
         """Raise error - saving not supported for Agent Executors.
 
```
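The new `_action_agent` property is the usual "narrow on access" idiom: the declared field stays wide so construction is ergonomic, a validator normalizes the value up front, and one private property re-asserts the narrowed type so the internal call sites (all the `self.agent.*` to `self._action_agent.*` changes below) do not each need their own `isinstance` check. A self-contained sketch of the idiom, with hypothetical classes rather than langchain's:

```python
from typing import Union, cast


class Wide:
    """Plays the role of Runnable: accepted at construction, never run as-is."""


class Narrow(Wide):
    """Plays the role of RunnableAgent: what the validator produces."""

    def plan(self) -> str:
        return "next step"


class Executor:
    def __init__(self, agent: Union[Wide, Narrow]) -> None:
        # Mimics validate_runnable_agent: normalize the wide input up front.
        self.agent: Union[Wide, Narrow] = agent if isinstance(agent, Narrow) else Narrow()

    @property
    def _action_agent(self) -> Narrow:
        # The field annotation still mentions Wide; cast() records, for the
        # type checker, the invariant the constructor already established.
        return cast(Narrow, self.agent)


print(Executor(Wide())._action_agent.plan())  # -> "next step"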
```diff
@@ -1193,7 +1210,7 @@
         Args:
             file_path: Path to save to.
         """
-        return self.agent.save(file_path)
+        return self._action_agent.save(file_path)
 
     def iter(
         self,
@@ -1228,7 +1245,7 @@
 
         :meta private:
         """
-        return self.agent.input_keys
+        return self._action_agent.input_keys
 
     @property
     def output_keys(self) -> List[str]:
@@ -1237,9 +1254,9 @@
         :meta private:
         """
         if self.return_intermediate_steps:
-            return self.agent.return_values + ["intermediate_steps"]
+            return self._action_agent.return_values + ["intermediate_steps"]
         else:
-            return self.agent.return_values
+            return self._action_agent.return_values
 
     def lookup_tool(self, name: str) -> BaseTool:
         """Lookup tool by name.
@@ -1339,7 +1356,7 @@
         intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
 
         # Call the LLM to see what to do.
-        output = self.agent.plan(
+        output = self._action_agent.plan(
             intermediate_steps,
             callbacks=run_manager.get_child() if run_manager else None,
             **inputs,
@@ -1372,7 +1389,7 @@
                 output = AgentAction("_Exception", observation, text)
             if run_manager:
                 run_manager.on_agent_action(output, color="green")
-            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
             observation = ExceptionTool().run(
                 output.tool_input,
                 verbose=self.verbose,
@@ -1414,7 +1431,7 @@
             tool = name_to_tool_map[agent_action.tool]
             return_direct = tool.return_direct
             color = color_mapping[agent_action.tool]
-            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
             if return_direct:
                 tool_run_kwargs["llm_prefix"] = ""
             # We then call the tool on the tool input to get an observation
@@ -1426,7 +1443,7 @@
                 **tool_run_kwargs,
             )
         else:
-            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
             observation = InvalidTool().run(
                 {
                     "requested_tool_name": agent_action.tool,
@@ -1476,7 +1493,7 @@
         intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
 
         # Call the LLM to see what to do.
-        output = await self.agent.aplan(
+        output = await self._action_agent.aplan(
             intermediate_steps,
             callbacks=run_manager.get_child() if run_manager else None,
             **inputs,
@@ -1507,7 +1524,7 @@
             else:
                 raise ValueError("Got unexpected type of `handle_parsing_errors`")
             output = AgentAction("_Exception", observation, text)
-            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
             observation = await ExceptionTool().arun(
                 output.tool_input,
                 verbose=self.verbose,
@@ -1561,7 +1578,7 @@
             tool = name_to_tool_map[agent_action.tool]
             return_direct = tool.return_direct
             color = color_mapping[agent_action.tool]
-            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
             if return_direct:
                 tool_run_kwargs["llm_prefix"] = ""
             # We then call the tool on the tool input to get an observation
@@ -1573,7 +1590,7 @@
                 **tool_run_kwargs,
             )
         else:
-            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
             observation = await InvalidTool().arun(
                 {
                     "requested_tool_name": agent_action.tool,
@@ -1628,7 +1645,7 @@
             )
             iterations += 1
             time_elapsed = time.time() - start_time
-        output = self.agent.return_stopped_response(
+        output = self._action_agent.return_stopped_response(
             self.early_stopping_method, intermediate_steps, **inputs
         )
         return self._return(output, intermediate_steps, run_manager=run_manager)
@@ -1680,7 +1697,7 @@
 
                 iterations += 1
                 time_elapsed = time.time() - start_time
-            output = self.agent.return_stopped_response(
+            output = self._action_agent.return_stopped_response(
                 self.early_stopping_method, intermediate_steps, **inputs
             )
             return await self._areturn(
@@ -1688,7 +1705,7 @@
             )
         except (TimeoutError, asyncio.TimeoutError):
             # stop early when interrupted by the async timeout
-            output = self.agent.return_stopped_response(
+            output = self._action_agent.return_stopped_response(
                 self.early_stopping_method, intermediate_steps, **inputs
             )
             return await self._areturn(
@@ -1702,8 +1719,8 @@
         agent_action, observation = next_step_output
         name_to_tool_map = {tool.name: tool for tool in self.tools}
         return_value_key = "output"
-        if len(self.agent.return_values) > 0:
-            return_value_key = self.agent.return_values[0]
+        if len(self._action_agent.return_values) > 0:
+            return_value_key = self._action_agent.return_values[0]
         # Invalid tools won't be in the map, so we return False.
         if agent_action.tool in name_to_tool_map:
             if name_to_tool_map[agent_action.tool].return_direct:
```
```diff
@@ -371,7 +371,7 @@ class AgentExecutorIterator:
         """
         logger.warning("Stopping agent prematurely due to triggering stop condition")
         # this manually constructs agent finish with output key
-        output = self.agent_executor.agent.return_stopped_response(
+        output = self.agent_executor._action_agent.return_stopped_response(
             self.agent_executor.early_stopping_method,
             self.intermediate_steps,
             **self.inputs,
@@ -384,7 +384,7 @@ class AgentExecutorIterator:
         the stopped response.
         """
         logger.warning("Stopping agent prematurely due to triggering stop condition")
-        output = self.agent_executor.agent.return_stopped_response(
+        output = self.agent_executor._action_agent.return_stopped_response(
             self.agent_executor.early_stopping_method,
             self.intermediate_steps,
             **self.inputs,
```
```diff
@@ -7,7 +7,7 @@ from langchain_core._api import deprecated
 
 @deprecated(
     "0.1.0",
-    alternative=(
+    message=(
         "Use new agent constructor methods like create_react_agent, create_json_agent, "
         "create_structured_chat_agent, etc."
     ),
```
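The same `alternative=` to `message=` swap appears three times in this diff. My reading of `langchain_core._api.deprecated` (not quoted from its docs) is that `alternative` should carry the *name* of a drop-in replacement, around which the decorator generates the warning sentence, so passing full prose to it produced awkward warning text; `message` replaces the generated text wholesale. A sketch of the distinction with hypothetical functions:

```python
import warnings

from langchain_core._api import deprecated


@deprecated("0.1.0", alternative="create_react_agent", removal="1.0")
def old_agent_helper() -> None:
    """Deprecated helper: warning text is generated around the alternative's name."""


@deprecated(
    "0.1.0",
    message=(
        "Use new agent constructor methods like create_react_agent, "
        "create_json_agent, create_structured_chat_agent, etc."
    ),
    removal="1.0",
)
def old_agent_helper_2() -> None:
    """Deprecated helper: the supplied message replaces the generated text."""


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_agent_helper()
    old_agent_helper_2()
    for w in caught:
        print(w.message)  # compare the two emitted warnings
```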
```diff
@@ -272,7 +272,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             instructions=instructions,
             tools=[_get_assistants_tool(tool) for tool in tools],  # type: ignore
             model=model,
-            file_ids=kwargs.get("file_ids"),
         )
         return cls(assistant_id=assistant.id, client=client, **kwargs)
 
@@ -287,7 +286,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             thread_id: Existing thread to use.
             run_id: Existing run to use. Should only be supplied when providing
                 the tool output for a required action after an initial invocation.
-            file_ids: File ids to include in new run. Used for retrieval.
             message_metadata: Metadata to associate with new message.
             thread_metadata: Metadata to associate with new thread. Only relevant
                 when new thread being created.
@@ -327,7 +325,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                     {
                         "role": "user",
                         "content": input["content"],
-                        "file_ids": input.get("file_ids", []),
                         "metadata": input.get("message_metadata"),
                     }
                 ],
@@ -340,7 +337,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                 input["thread_id"],
                 content=input["content"],
                 role="user",
-                file_ids=input.get("file_ids", []),
                 metadata=input.get("message_metadata"),
             )
             run = self._create_run(input)
@@ -394,7 +390,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             instructions=instructions,
             tools=openai_tools,  # type: ignore
             model=model,
-            file_ids=kwargs.get("file_ids"),
         )
         return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
 
@@ -409,7 +404,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             thread_id: Existing thread to use.
             run_id: Existing run to use. Should only be supplied when providing
                 the tool output for a required action after an initial invocation.
-            file_ids: File ids to include in new run. Used for retrieval.
             message_metadata: Metadata to associate with a new message.
             thread_metadata: Metadata to associate with new thread. Only relevant
                 when a new thread is created.
@@ -439,7 +433,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
         try:
             # Being run within AgentExecutor and there are tool outputs to submit.
             if self.as_agent and input.get("intermediate_steps"):
-                tool_outputs = self._parse_intermediate_steps(
+                tool_outputs = await self._aparse_intermediate_steps(
                     input["intermediate_steps"]
                 )
                 run = await self.async_client.beta.threads.runs.submit_tool_outputs(
```
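The last hunk above is a one-line bug fix: the async code path was calling the synchronous `_parse_intermediate_steps` helper (and therefore the sync client) from inside the async entry point. The shape of the bug and the fix, reduced to a self-contained sketch with hypothetical names:

```python
import asyncio


class AssistantLike:
    def _parse_intermediate_steps(self) -> str:
        # Sync helper: fine from invoke(), wrong from the async path, where it
        # would block the event loop and rely on the sync client.
        return "parsed (sync client)"

    async def _aparse_intermediate_steps(self) -> str:
        await asyncio.sleep(0)  # stands in for awaiting the async client
        return "parsed (async client)"

    async def ainvoke(self) -> str:
        # The fix: await the async counterpart instead of calling the sync one.
        return await self._aparse_intermediate_steps()


print(asyncio.run(AssistantLike().ainvoke()))  # -> "parsed (async client)"
```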
```diff
@@ -452,7 +446,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                     {
                         "role": "user",
                         "content": input["content"],
-                        "file_ids": input.get("file_ids", []),
                         "metadata": input.get("message_metadata"),
                     }
                 ],
@@ -465,7 +458,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                 input["thread_id"],
                 content=input["content"],
                 role="user",
-                file_ids=input.get("file_ids", []),
                 metadata=input.get("message_metadata"),
             )
             run = await self._acreate_run(input)
@@ -493,9 +485,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
     ) -> dict:
         last_action, last_output = intermediate_steps[-1]
         run = self._wait_for_run(last_action.run_id, last_action.thread_id)
-        required_tool_call_ids = {
-            tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
-        }
+        required_tool_call_ids = set()
+        if run.required_action:
+            required_tool_call_ids = {
+                tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
+            }
         tool_outputs = [
             {"output": str(output), "tool_call_id": action.tool_call_id}
             for action, output in intermediate_steps
@@ -621,9 +615,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
     ) -> dict:
         last_action, last_output = intermediate_steps[-1]
         run = await self._wait_for_run(last_action.run_id, last_action.thread_id)
-        required_tool_call_ids = {
-            tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
-        }
+        required_tool_call_ids = set()
+        if run.required_action:
+            required_tool_call_ids = {
+                tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
+            }
         tool_outputs = [
             {"output": str(output), "tool_call_id": action.tool_call_id}
             for action, output in intermediate_steps
```
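Both `_parse_intermediate_steps` variants now tolerate a run whose `required_action` is `None` (a run that no longer needs tool outputs), where the old set comprehension raised `AttributeError`. The guard in isolation, with minimal stand-in types for the OpenAI run objects (assumptions, not the real client classes):

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class ToolCall:
    id: str


@dataclass
class SubmitToolOutputs:
    tool_calls: List[ToolCall]


@dataclass
class RequiredAction:
    submit_tool_outputs: SubmitToolOutputs


@dataclass
class Run:
    required_action: Optional[RequiredAction]


def required_tool_call_ids(run: Run) -> Set[str]:
    ids: Set[str] = set()
    if run.required_action:  # the old code dereferenced this unconditionally
        ids = {tc.id for tc in run.required_action.submit_tool_outputs.tool_calls}
    return ids


assert required_tool_call_ids(Run(required_action=None)) == set()
assert required_tool_call_ids(
    Run(RequiredAction(SubmitToolOutputs([ToolCall("call_1")])))
) == {"call_1"}
```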
```diff
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Optional, Sequence
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.chat import ChatPromptTemplate
@@ -13,7 +13,10 @@ from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputP
 
 
 def create_openai_tools_agent(
-    llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
+    llm: BaseLanguageModel,
+    tools: Sequence[BaseTool],
+    prompt: ChatPromptTemplate,
+    strict: Optional[bool] = None,
 ) -> Runnable:
     """Create an agent that uses OpenAI tools.
 
@@ -87,7 +90,9 @@ def create_openai_tools_agent(
     if missing_vars:
         raise ValueError(f"Prompt missing required variables: {missing_vars}")
 
-    llm_with_tools = llm.bind(tools=[convert_to_openai_tool(tool) for tool in tools])
+    llm_with_tools = llm.bind(
+        tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools]
+    )
 
     agent = (
         RunnablePassthrough.assign(
```
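The new `strict` flag is threaded straight through to `convert_to_openai_tool`, which in sufficiently recent langchain-core releases can emit OpenAI strict function schemas; the `None` default preserves the old behavior. A usage sketch (assumes `langchain-openai` is installed and an OpenAI API key is configured; the model name is illustrative):

```python
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from langchain.agents import AgentExecutor, create_openai_tools_agent


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant"),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
llm = ChatOpenAI(model="gpt-4o-mini")

# strict=True requests strict schema validation for the bound tools;
# strict=None (the default) matches pre-0.2.16 behavior.
agent = create_openai_tools_agent(llm, [add], prompt, strict=True)
executor = AgentExecutor(agent=agent, tools=[add])
print(executor.invoke({"input": "What is 2 + 2?"})["output"])
```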
```diff
@@ -22,11 +22,11 @@ DOCUMENTS_KEY = "context"
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
 
 
-def _validate_prompt(prompt: BasePromptTemplate) -> None:
-    if DOCUMENTS_KEY not in prompt.input_variables:
+def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
+    if document_variable_name not in prompt.input_variables:
         raise ValueError(
-            f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
-            f"with input variables: {prompt.input_variables}"
+            f"Prompt must accept {document_variable_name} as an input variable. "
+            f"Received prompt with input variables: {prompt.input_variables}"
         )
 
 
@@ -76,7 +76,7 @@ def create_stuff_documents_chain(
             chain.invoke({"context": docs})
     """  # noqa: E501
 
-    _validate_prompt(prompt)
+    _validate_prompt(prompt, document_variable_name)
     _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
     _output_parser = output_parser or StrOutputParser()
 
```
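`create_stuff_documents_chain` already accepted a custom `document_variable_name`, but validation previously insisted on the literal `"context"` key (the `DOCUMENTS_KEY` default); it now checks the variable actually in use. A short sketch (again assuming `langchain-openai` and a configured API key):

```python
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from langchain.chains.combine_documents import create_stuff_documents_chain

prompt = ChatPromptTemplate.from_messages(
    [("system", "What are everyone's favorite colors?\n\n{docs}")]
)
llm = ChatOpenAI(model="gpt-4o-mini")

# Before this fix, the line below raised "Prompt must accept context ..."
# even though document_variable_name was explicitly set to "docs".
chain = create_stuff_documents_chain(llm, prompt, document_variable_name="docs")
chain.invoke({"docs": [Document(page_content="Jesse loves red but not yellow")]})
```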
```diff
@@ -2,6 +2,7 @@
 
 from typing import Any, Dict, List, Optional
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import BasePromptTemplate
@@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
 from langchain.chains.llm import LLMChain
 
 
+@deprecated(
+    since="0.2.13",
+    message=(
+        "This class is deprecated and will be removed in langchain 1.0. "
+        "See API reference for replacement: "
+        "https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html"  # noqa: E501
+    ),
+    removal="1.0",
+)
 class ConstitutionalChain(Chain):
     """Chain for applying constitutional principles.
 
+    Note: this class is deprecated. See below for a replacement implementation
+    using LangGraph. The benefits of this implementation are:
+
+    - Uses LLM tool calling features instead of parsing string responses;
+    - Support for both token-by-token and step-by-step streaming;
+    - Support for checkpointing and memory of chat history;
+    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
+
+    Install LangGraph with:
+
+    .. code-block:: bash
+
+        pip install -U langgraph
+
+    .. code-block:: python
+
+        from typing import List, Optional, Tuple
+
+        from langchain.chains.constitutional_ai.prompts import (
+            CRITIQUE_PROMPT,
+            REVISION_PROMPT,
+        )
+        from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
+        from langchain_core.output_parsers import StrOutputParser
+        from langchain_core.prompts import ChatPromptTemplate
+        from langchain_openai import ChatOpenAI
+        from langgraph.graph import END, START, StateGraph
+        from typing_extensions import Annotated, TypedDict
+
+        llm = ChatOpenAI(model="gpt-4o-mini")
+
+        class Critique(TypedDict):
+            \"\"\"Generate a critique, if needed.\"\"\"
+            critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
+            critique: Annotated[str, ..., "If needed, the critique."]
+
+        critique_prompt = ChatPromptTemplate.from_template(
+            "Critique this response according to the critique request. "
+            "If no critique is needed, specify that.\\n\\n"
+            "Query: {query}\\n\\n"
+            "Response: {response}\\n\\n"
+            "Critique request: {critique_request}"
+        )
+
+        revision_prompt = ChatPromptTemplate.from_template(
+            "Revise this response according to the critique and revision request.\\n\\n"
+            "Query: {query}\\n\\n"
+            "Response: {response}\\n\\n"
+            "Critique request: {critique_request}\\n\\n"
+            "Critique: {critique}\\n\\n"
+            "If the critique does not identify anything worth changing, ignore the "
+            "revision request and return 'No revisions needed'. If the critique "
+            "does identify something worth changing, revise the response based on "
+            "the revision request.\\n\\n"
+            "Revision Request: {revision_request}"
+        )
+
+        chain = llm | StrOutputParser()
+        critique_chain = critique_prompt | llm.with_structured_output(Critique)
+        revision_chain = revision_prompt | llm | StrOutputParser()
+
+
+        class State(TypedDict):
+            query: str
+            constitutional_principles: List[ConstitutionalPrinciple]
+            initial_response: str
+            critiques_and_revisions: List[Tuple[str, str]]
+            response: str
+
+
+        async def generate_response(state: State):
+            \"\"\"Generate initial response.\"\"\"
+            response = await chain.ainvoke(state["query"])
+            return {"response": response, "initial_response": response}
+
+        async def critique_and_revise(state: State):
+            \"\"\"Critique and revise response according to principles.\"\"\"
+            critiques_and_revisions = []
+            response = state["initial_response"]
+            for principle in state["constitutional_principles"]:
+                critique = await critique_chain.ainvoke(
+                    {
+                        "query": state["query"],
+                        "response": response,
+                        "critique_request": principle.critique_request,
+                    }
+                )
+                if critique["critique_needed"]:
+                    revision = await revision_chain.ainvoke(
+                        {
+                            "query": state["query"],
+                            "response": response,
+                            "critique_request": principle.critique_request,
+                            "critique": critique["critique"],
+                            "revision_request": principle.revision_request,
+                        }
+                    )
+                    response = revision
+                    critiques_and_revisions.append((critique["critique"], revision))
+                else:
+                    critiques_and_revisions.append((critique["critique"], ""))
+            return {
+                "critiques_and_revisions": critiques_and_revisions,
+                "response": response,
+            }
+
+        graph = StateGraph(State)
+        graph.add_node("generate_response", generate_response)
+        graph.add_node("critique_and_revise", critique_and_revise)
+
+        graph.add_edge(START, "generate_response")
+        graph.add_edge("generate_response", "critique_and_revise")
+        graph.add_edge("critique_and_revise", END)
+        app = graph.compile()
+
+    .. code-block:: python
+
+        constitutional_principles = [
+            ConstitutionalPrinciple(
+                critique_request="Tell if this answer is good.",
+                revision_request="Give a better answer.",
+            )
+        ]
+
+        query = "What is the meaning of life? Answer in 10 words or fewer."
+
+        async for step in app.astream(
+            {"query": query, "constitutional_principles": constitutional_principles},
+            stream_mode="values",
+        ):
+            subset = ["initial_response", "critiques_and_revisions", "response"]
+            print({k: v for k, v in step.items() if k in subset})
+
     Example:
         .. code-block:: python
 
@@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
         )
 
         constitutional_chain.run(question="What is the meaning of life?")
-    """
+    """  # noqa: E501
 
     chain: LLMChain
     constitutional_principles: List[ConstitutionalPrinciple]
```
```diff
@@ -16,7 +16,7 @@ from langchain.memory.buffer import ConversationBufferMemory
     since="0.2.7",
     alternative=(
         "RunnableWithMessageHistory: "
-        "https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html"  # noqa: E501
+        "https://python.langchain.com/v0.2/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html"  # noqa: E501
     ),
     removal="1.0",
 )
```