langchain-core 0.3.72__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. langchain_core/_api/beta_decorator.py +1 -0
  2. langchain_core/_api/deprecation.py +2 -0
  3. langchain_core/beta/runnables/context.py +1 -0
  4. langchain_core/callbacks/base.py +23 -14
  5. langchain_core/callbacks/file.py +1 -0
  6. langchain_core/callbacks/manager.py +145 -19
  7. langchain_core/callbacks/streaming_stdout.py +4 -3
  8. langchain_core/callbacks/usage.py +15 -3
  9. langchain_core/chat_history.py +1 -0
  10. langchain_core/document_loaders/langsmith.py +2 -1
  11. langchain_core/documents/base.py +2 -0
  12. langchain_core/embeddings/fake.py +2 -0
  13. langchain_core/indexing/api.py +10 -0
  14. langchain_core/language_models/_utils.py +37 -0
  15. langchain_core/language_models/base.py +4 -1
  16. langchain_core/language_models/chat_models.py +48 -27
  17. langchain_core/language_models/fake_chat_models.py +71 -1
  18. langchain_core/language_models/llms.py +1 -0
  19. langchain_core/memory.py +1 -0
  20. langchain_core/messages/__init__.py +54 -0
  21. langchain_core/messages/ai.py +31 -18
  22. langchain_core/messages/content_blocks.py +1349 -69
  23. langchain_core/messages/human.py +1 -0
  24. langchain_core/messages/modifier.py +1 -1
  25. langchain_core/messages/tool.py +8 -83
  26. langchain_core/messages/utils.py +221 -6
  27. langchain_core/output_parsers/base.py +51 -14
  28. langchain_core/output_parsers/json.py +5 -2
  29. langchain_core/output_parsers/list.py +7 -2
  30. langchain_core/output_parsers/openai_functions.py +29 -5
  31. langchain_core/output_parsers/openai_tools.py +90 -47
  32. langchain_core/output_parsers/pydantic.py +3 -2
  33. langchain_core/output_parsers/transform.py +53 -12
  34. langchain_core/output_parsers/xml.py +14 -5
  35. langchain_core/outputs/llm_result.py +4 -1
  36. langchain_core/prompt_values.py +111 -7
  37. langchain_core/prompts/base.py +4 -0
  38. langchain_core/prompts/chat.py +3 -0
  39. langchain_core/prompts/few_shot.py +1 -0
  40. langchain_core/prompts/few_shot_with_templates.py +1 -0
  41. langchain_core/prompts/image.py +1 -0
  42. langchain_core/prompts/pipeline.py +1 -0
  43. langchain_core/prompts/prompt.py +1 -0
  44. langchain_core/prompts/structured.py +1 -0
  45. langchain_core/rate_limiters.py +1 -0
  46. langchain_core/retrievers.py +3 -0
  47. langchain_core/runnables/base.py +75 -57
  48. langchain_core/runnables/branch.py +1 -0
  49. langchain_core/runnables/config.py +2 -2
  50. langchain_core/runnables/configurable.py +2 -1
  51. langchain_core/runnables/fallbacks.py +3 -7
  52. langchain_core/runnables/graph.py +5 -3
  53. langchain_core/runnables/graph_ascii.py +1 -0
  54. langchain_core/runnables/graph_mermaid.py +1 -0
  55. langchain_core/runnables/history.py +1 -0
  56. langchain_core/runnables/passthrough.py +3 -0
  57. langchain_core/runnables/retry.py +1 -0
  58. langchain_core/runnables/router.py +1 -0
  59. langchain_core/runnables/schema.py +1 -0
  60. langchain_core/stores.py +3 -0
  61. langchain_core/tools/base.py +43 -11
  62. langchain_core/tools/convert.py +25 -3
  63. langchain_core/tools/retriever.py +8 -1
  64. langchain_core/tools/structured.py +10 -1
  65. langchain_core/tracers/base.py +14 -7
  66. langchain_core/tracers/context.py +1 -1
  67. langchain_core/tracers/core.py +27 -4
  68. langchain_core/tracers/event_stream.py +14 -3
  69. langchain_core/tracers/langchain.py +14 -3
  70. langchain_core/tracers/log_stream.py +4 -1
  71. langchain_core/utils/aiter.py +5 -0
  72. langchain_core/utils/function_calling.py +2 -1
  73. langchain_core/utils/iter.py +1 -0
  74. langchain_core/v1/__init__.py +1 -0
  75. langchain_core/v1/chat_models.py +1047 -0
  76. langchain_core/v1/messages.py +755 -0
  77. langchain_core/vectorstores/base.py +1 -0
  78. langchain_core/version.py +1 -1
  79. {langchain_core-0.3.72.dist-info → langchain_core-0.4.0.dev0.dist-info}/METADATA +1 -1
  80. {langchain_core-0.3.72.dist-info → langchain_core-0.4.0.dev0.dist-info}/RECORD +82 -79
  81. {langchain_core-0.3.72.dist-info → langchain_core-0.4.0.dev0.dist-info}/WHEEL +0 -0
  82. {langchain_core-0.3.72.dist-info → langchain_core-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
@@ -13,6 +13,7 @@ from typing_extensions import override
13
13
 
14
14
  from langchain_core.messages import BaseMessage
15
15
  from langchain_core.output_parsers.transform import BaseTransformOutputParser
16
+ from langchain_core.v1.messages import AIMessage
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from collections.abc import AsyncIterator, Iterator
@@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
71
72
 
72
73
  @override
73
74
  def _transform(
74
- self, input: Iterator[Union[str, BaseMessage]]
75
+ self, input: Iterator[Union[str, BaseMessage, AIMessage]]
75
76
  ) -> Iterator[list[str]]:
76
77
  buffer = ""
77
78
  for chunk in input:
@@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
81
82
  if not isinstance(chunk_content, str):
82
83
  continue
83
84
  buffer += chunk_content
85
+ elif isinstance(chunk, AIMessage):
86
+ buffer += chunk.text
84
87
  else:
85
88
  # add current chunk to buffer
86
89
  buffer += chunk
@@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
105
108
 
106
109
  @override
107
110
  async def _atransform(
108
- self, input: AsyncIterator[Union[str, BaseMessage]]
111
+ self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
109
112
  ) -> AsyncIterator[list[str]]:
110
113
  buffer = ""
111
114
  async for chunk in input:
@@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
115
118
  if not isinstance(chunk_content, str):
116
119
  continue
117
120
  buffer += chunk_content
121
+ elif isinstance(chunk, AIMessage):
122
+ buffer += chunk.text
118
123
  else:
119
124
  # add current chunk to buffer
120
125
  buffer += chunk
@@ -17,6 +17,7 @@ from langchain_core.output_parsers import (
17
17
  )
18
18
  from langchain_core.output_parsers.json import parse_partial_json
19
19
  from langchain_core.outputs import ChatGeneration, Generation
20
+ from langchain_core.v1.messages import AIMessage
20
21
 
21
22
 
22
23
  class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
26
27
  """Whether to only return the arguments to the function call."""
27
28
 
28
29
  @override
29
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
30
+ def parse_result(
31
+ self, result: Union[list[Generation], AIMessage], *, partial: bool = False
32
+ ) -> Any:
30
33
  """Parse the result of an LLM call to a JSON object.
31
34
 
32
35
  Args:
@@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
39
42
  Raises:
40
43
  OutputParserException: If the output is not valid JSON.
41
44
  """
45
+ if isinstance(result, AIMessage):
46
+ msg = (
47
+ "This output parser does not support v1 AIMessages. Use "
48
+ "JsonOutputToolsParser instead."
49
+ )
50
+ raise TypeError(msg)
42
51
  generation = result[0]
43
52
  if not isinstance(generation, ChatGeneration):
44
53
  msg = "This output parser can only be used with a chat generation."
@@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
77
86
  def _diff(self, prev: Optional[Any], next: Any) -> Any:
78
87
  return jsonpatch.make_patch(prev, next).patch
79
88
 
80
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
89
+ def parse_result(
90
+ self, result: Union[list[Generation], AIMessage], *, partial: bool = False
91
+ ) -> Any:
81
92
  """Parse the result of an LLM call to a JSON object.
82
93
 
83
94
  Args:
@@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
90
101
  Raises:
91
102
  OutputParserException: If the output is not valid JSON.
92
103
  """
104
+ if isinstance(result, AIMessage):
105
+ msg = (
106
+ "This output parser does not support v1 AIMessages. Use "
107
+ "JsonOutputToolsParser instead."
108
+ )
109
+ raise TypeError(msg)
93
110
  if len(result) != 1:
94
111
  msg = f"Expected exactly one result, but got {len(result)}"
95
112
  raise OutputParserException(msg)
@@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
160
177
  key_name: str
161
178
  """The name of the key to return."""
162
179
 
163
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
180
+ def parse_result(
181
+ self, result: Union[list[Generation], AIMessage], *, partial: bool = False
182
+ ) -> Any:
164
183
  """Parse the result of an LLM call to a JSON object.
165
184
 
166
185
  Args:
@@ -214,6 +233,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
214
233
  pydantic_schema={"cookie": Cookie, "dog": Dog}
215
234
  )
216
235
  result = parser.parse_result([chat_generation])
236
+
217
237
  """
218
238
 
219
239
  pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]]
@@ -253,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
253
273
  return values
254
274
 
255
275
  @override
256
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
276
+ def parse_result(
277
+ self, result: Union[list[Generation], AIMessage], *, partial: bool = False
278
+ ) -> Any:
257
279
  """Parse the result of an LLM call to a JSON object.
258
280
 
259
281
  Args:
@@ -293,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
293
315
  """The name of the attribute to return."""
294
316
 
295
317
  @override
296
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
318
+ def parse_result(
319
+ self, result: Union[list[Generation], AIMessage], *, partial: bool = False
320
+ ) -> Any:
297
321
  """Parse the result of an LLM call to a JSON object.
298
322
 
299
323
  Args:
@@ -4,7 +4,7 @@ import copy
4
4
  import json
5
5
  import logging
6
6
  from json import JSONDecodeError
7
- from typing import Annotated, Any, Optional
7
+ from typing import Annotated, Any, Optional, Union
8
8
 
9
9
  from pydantic import SkipValidation, ValidationError
10
10
 
@@ -16,6 +16,7 @@ from langchain_core.output_parsers.transform import BaseCumulativeTransformOutpu
16
16
  from langchain_core.outputs import ChatGeneration, Generation
17
17
  from langchain_core.utils.json import parse_partial_json
18
18
  from langchain_core.utils.pydantic import TypeBaseModel
19
+ from langchain_core.v1.messages import AIMessage as AIMessageV1
19
20
 
20
21
  logger = logging.getLogger(__name__)
21
22
 
@@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
156
157
  If no tool calls are found, None will be returned.
157
158
  """
158
159
 
159
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
160
+ def parse_result(
161
+ self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
162
+ ) -> Any:
160
163
  """Parse the result of an LLM call to a list of tool calls.
161
164
 
162
165
  Args:
@@ -173,31 +176,45 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
173
176
  Raises:
174
177
  OutputParserException: If the output is not valid JSON.
175
178
  """
176
- generation = result[0]
177
- if not isinstance(generation, ChatGeneration):
178
- msg = "This output parser can only be used with a chat generation."
179
- raise OutputParserException(msg)
180
- message = generation.message
181
- if isinstance(message, AIMessage) and message.tool_calls:
182
- tool_calls = [dict(tc) for tc in message.tool_calls]
179
+ if isinstance(result, list):
180
+ generation = result[0]
181
+ if not isinstance(generation, ChatGeneration):
182
+ msg = (
183
+ "This output parser can only be used with a chat generation or "
184
+ "v1 AIMessage."
185
+ )
186
+ raise OutputParserException(msg)
187
+ message = generation.message
188
+ if isinstance(message, AIMessage) and message.tool_calls:
189
+ tool_calls = [dict(tc) for tc in message.tool_calls]
190
+ for tool_call in tool_calls:
191
+ if not self.return_id:
192
+ _ = tool_call.pop("id")
193
+ else:
194
+ try:
195
+ raw_tool_calls = copy.deepcopy(
196
+ message.additional_kwargs["tool_calls"]
197
+ )
198
+ except KeyError:
199
+ return []
200
+ tool_calls = parse_tool_calls(
201
+ raw_tool_calls,
202
+ partial=partial,
203
+ strict=self.strict,
204
+ return_id=self.return_id,
205
+ )
206
+ elif result.tool_calls:
207
+ # v1 message
208
+ tool_calls = [dict(tc) for tc in result.tool_calls]
183
209
  for tool_call in tool_calls:
184
210
  if not self.return_id:
185
211
  _ = tool_call.pop("id")
186
212
  else:
187
- try:
188
- raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
189
- except KeyError:
190
- return []
191
- tool_calls = parse_tool_calls(
192
- raw_tool_calls,
193
- partial=partial,
194
- strict=self.strict,
195
- return_id=self.return_id,
196
- )
213
+ return []
214
+
197
215
  # for backwards compatibility
198
216
  for tc in tool_calls:
199
217
  tc["type"] = tc.pop("name")
200
-
201
218
  if self.first_tool_only:
202
219
  return tool_calls[0] if tool_calls else None
203
220
  return tool_calls
@@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
220
237
  key_name: str
221
238
  """The type of tools to return."""
222
239
 
223
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
240
+ def parse_result(
241
+ self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
242
+ ) -> Any:
224
243
  """Parse the result of an LLM call to a list of tool calls.
225
244
 
226
245
  Args:
@@ -234,32 +253,47 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
234
253
  Returns:
235
254
  The parsed tool calls.
236
255
  """
237
- generation = result[0]
238
- if not isinstance(generation, ChatGeneration):
239
- msg = "This output parser can only be used with a chat generation."
240
- raise OutputParserException(msg)
241
- message = generation.message
242
- if isinstance(message, AIMessage) and message.tool_calls:
243
- parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
256
+ if isinstance(result, list):
257
+ generation = result[0]
258
+ if not isinstance(generation, ChatGeneration):
259
+ msg = "This output parser can only be used with a chat generation."
260
+ raise OutputParserException(msg)
261
+ message = generation.message
262
+ if isinstance(message, AIMessage) and message.tool_calls:
263
+ parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
264
+ for tool_call in parsed_tool_calls:
265
+ if not self.return_id:
266
+ _ = tool_call.pop("id")
267
+ else:
268
+ try:
269
+ raw_tool_calls = copy.deepcopy(
270
+ message.additional_kwargs["tool_calls"]
271
+ )
272
+ except KeyError:
273
+ if self.first_tool_only:
274
+ return None
275
+ return []
276
+ parsed_tool_calls = parse_tool_calls(
277
+ raw_tool_calls,
278
+ partial=partial,
279
+ strict=self.strict,
280
+ return_id=self.return_id,
281
+ )
282
+ elif result.tool_calls:
283
+ # v1 message
284
+ parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
244
285
  for tool_call in parsed_tool_calls:
245
286
  if not self.return_id:
246
287
  _ = tool_call.pop("id")
247
288
  else:
248
- try:
249
- raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
250
- except KeyError:
251
- if self.first_tool_only:
252
- return None
253
- return []
254
- parsed_tool_calls = parse_tool_calls(
255
- raw_tool_calls,
256
- partial=partial,
257
- strict=self.strict,
258
- return_id=self.return_id,
259
- )
289
+ if self.first_tool_only:
290
+ return None
291
+ return []
292
+
260
293
  # For backwards compatibility
261
294
  for tc in parsed_tool_calls:
262
295
  tc["type"] = tc.pop("name")
296
+
263
297
  if self.first_tool_only:
264
298
  parsed_result = list(
265
299
  filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
@@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
299
333
 
300
334
  # TODO: Support more granular streaming of objects. Currently only streams once all
301
335
  # Pydantic object fields are present.
302
- def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
336
+ def parse_result(
337
+ self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
338
+ ) -> Any:
303
339
  """Parse the result of an LLM call to a list of Pydantic objects.
304
340
 
305
341
  Args:
@@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
337
373
  except (ValidationError, ValueError):
338
374
  if partial:
339
375
  continue
340
- has_max_tokens_stop_reason = any(
341
- generation.message.response_metadata.get("stop_reason")
342
- == "max_tokens"
343
- for generation in result
344
- if isinstance(generation, ChatGeneration)
345
- )
376
+ has_max_tokens_stop_reason = False
377
+ if isinstance(result, list):
378
+ has_max_tokens_stop_reason = any(
379
+ generation.message.response_metadata.get("stop_reason")
380
+ == "max_tokens"
381
+ for generation in result
382
+ if isinstance(generation, ChatGeneration)
383
+ )
384
+ else:
385
+ # v1 message
386
+ has_max_tokens_stop_reason = (
387
+ result.response_metadata.get("stop_reason") == "max_tokens"
388
+ )
346
389
  if has_max_tokens_stop_reason:
347
390
  logger.exception(_MAX_TOKENS_ERROR)
348
391
  raise
@@ -1,7 +1,7 @@
1
1
  """Output parsers using Pydantic."""
2
2
 
3
3
  import json
4
- from typing import Annotated, Generic, Optional
4
+ from typing import Annotated, Generic, Optional, Union
5
5
 
6
6
  import pydantic
7
7
  from pydantic import SkipValidation
@@ -14,6 +14,7 @@ from langchain_core.utils.pydantic import (
14
14
  PydanticBaseModel,
15
15
  TBaseModel,
16
16
  )
17
+ from langchain_core.v1.messages import AIMessage
17
18
 
18
19
 
19
20
  class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
@@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
43
44
  return OutputParserException(msg, llm_output=json_string)
44
45
 
45
46
  def parse_result(
46
- self, result: list[Generation], *, partial: bool = False
47
+ self, result: Union[list[Generation], AIMessage], *, partial: bool = False
47
48
  ) -> Optional[TBaseModel]:
48
49
  """Parse the result of an LLM call to a pydantic object.
49
50
 
@@ -20,6 +20,7 @@ from langchain_core.outputs import (
20
20
  GenerationChunk,
21
21
  )
22
22
  from langchain_core.runnables.config import run_in_executor
23
+ from langchain_core.v1.messages import AIMessage, AIMessageChunk
23
24
 
24
25
  if TYPE_CHECKING:
25
26
  from collections.abc import AsyncIterator, Iterator
@@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
32
33
 
33
34
  def _transform(
34
35
  self,
35
- input: Iterator[Union[str, BaseMessage]], # noqa: A002
36
+ input: Iterator[Union[str, BaseMessage, AIMessage]],
36
37
  ) -> Iterator[T]:
37
38
  for chunk in input:
38
39
  if isinstance(chunk, BaseMessage):
39
40
  yield self.parse_result([ChatGeneration(message=chunk)])
41
+ elif isinstance(chunk, AIMessage):
42
+ yield self.parse_result(chunk)
40
43
  else:
41
44
  yield self.parse_result([Generation(text=chunk)])
42
45
 
43
46
  async def _atransform(
44
47
  self,
45
- input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
48
+ input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
46
49
  ) -> AsyncIterator[T]:
47
50
  async for chunk in input:
48
51
  if isinstance(chunk, BaseMessage):
49
52
  yield await run_in_executor(
50
53
  None, self.parse_result, [ChatGeneration(message=chunk)]
51
54
  )
55
+ elif isinstance(chunk, AIMessage):
56
+ yield await run_in_executor(None, self.parse_result, chunk)
52
57
  else:
53
58
  yield await run_in_executor(
54
59
  None, self.parse_result, [Generation(text=chunk)]
@@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
57
62
  @override
58
63
  def transform(
59
64
  self,
60
- input: Iterator[Union[str, BaseMessage]],
65
+ input: Iterator[Union[str, BaseMessage, AIMessage]],
61
66
  config: Optional[RunnableConfig] = None,
62
67
  **kwargs: Any,
63
68
  ) -> Iterator[T]:
@@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
78
83
  @override
79
84
  async def atransform(
80
85
  self,
81
- input: AsyncIterator[Union[str, BaseMessage]],
86
+ input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
82
87
  config: Optional[RunnableConfig] = None,
83
88
  **kwargs: Any,
84
89
  ) -> AsyncIterator[T]:
@@ -125,23 +130,42 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
125
130
  raise NotImplementedError
126
131
 
127
132
  @override
128
- def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
133
+ def _transform(
134
+ self, input: Iterator[Union[str, BaseMessage, AIMessage]]
135
+ ) -> Iterator[Any]:
129
136
  prev_parsed = None
130
- acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
137
+ acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
138
+ None
139
+ )
131
140
  for chunk in input:
132
- chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
141
+ chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
133
142
  if isinstance(chunk, BaseMessageChunk):
134
143
  chunk_gen = ChatGenerationChunk(message=chunk)
135
144
  elif isinstance(chunk, BaseMessage):
136
145
  chunk_gen = ChatGenerationChunk(
137
146
  message=BaseMessageChunk(**chunk.model_dump())
138
147
  )
148
+ elif isinstance(chunk, AIMessageChunk):
149
+ chunk_gen = chunk
150
+ elif isinstance(chunk, AIMessage):
151
+ chunk_gen = AIMessageChunk(
152
+ content=chunk.content,
153
+ id=chunk.id,
154
+ name=chunk.name,
155
+ lc_version=chunk.lc_version,
156
+ response_metadata=chunk.response_metadata,
157
+ usage_metadata=chunk.usage_metadata,
158
+ parsed=chunk.parsed,
159
+ )
139
160
  else:
140
161
  chunk_gen = GenerationChunk(text=chunk)
141
162
 
142
163
  acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
143
164
 
144
- parsed = self.parse_result([acc_gen], partial=True)
165
+ if isinstance(acc_gen, AIMessageChunk):
166
+ parsed = self.parse_result(acc_gen, partial=True)
167
+ else:
168
+ parsed = self.parse_result([acc_gen], partial=True)
145
169
  if parsed is not None and parsed != prev_parsed:
146
170
  if self.diff:
147
171
  yield self._diff(prev_parsed, parsed)
@@ -151,24 +175,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
151
175
 
152
176
  @override
153
177
  async def _atransform(
154
- self, input: AsyncIterator[Union[str, BaseMessage]]
178
+ self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
155
179
  ) -> AsyncIterator[T]:
156
180
  prev_parsed = None
157
- acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
181
+ acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
182
+ None
183
+ )
158
184
  async for chunk in input:
159
- chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
185
+ chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
160
186
  if isinstance(chunk, BaseMessageChunk):
161
187
  chunk_gen = ChatGenerationChunk(message=chunk)
162
188
  elif isinstance(chunk, BaseMessage):
163
189
  chunk_gen = ChatGenerationChunk(
164
190
  message=BaseMessageChunk(**chunk.model_dump())
165
191
  )
192
+ elif isinstance(chunk, AIMessageChunk):
193
+ chunk_gen = chunk
194
+ elif isinstance(chunk, AIMessage):
195
+ chunk_gen = AIMessageChunk(
196
+ content=chunk.content,
197
+ id=chunk.id,
198
+ name=chunk.name,
199
+ lc_version=chunk.lc_version,
200
+ response_metadata=chunk.response_metadata,
201
+ usage_metadata=chunk.usage_metadata,
202
+ parsed=chunk.parsed,
203
+ )
166
204
  else:
167
205
  chunk_gen = GenerationChunk(text=chunk)
168
206
 
169
207
  acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
170
208
 
171
- parsed = await self.aparse_result([acc_gen], partial=True)
209
+ if isinstance(acc_gen, AIMessageChunk):
210
+ parsed = await self.aparse_result(acc_gen, partial=True)
211
+ else:
212
+ parsed = await self.aparse_result([acc_gen], partial=True)
172
213
  if parsed is not None and parsed != prev_parsed:
173
214
  if self.diff:
174
215
  yield await run_in_executor(None, self._diff, prev_parsed, parsed)
@@ -12,8 +12,10 @@ from typing_extensions import override
12
12
 
13
13
  from langchain_core.exceptions import OutputParserException
14
14
  from langchain_core.messages import BaseMessage
15
+ from langchain_core.messages.utils import convert_from_v1_message
15
16
  from langchain_core.output_parsers.transform import BaseTransformOutputParser
16
17
  from langchain_core.runnables.utils import AddableDict
18
+ from langchain_core.v1.messages import AIMessage
17
19
 
18
20
  XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
19
21
  1. Output should conform to the tags below.
@@ -240,21 +242,28 @@ class XMLOutputParser(BaseTransformOutputParser):
240
242
 
241
243
  @override
242
244
  def _transform(
243
- self, input: Iterator[Union[str, BaseMessage]]
245
+ self, input: Iterator[Union[str, BaseMessage, AIMessage]]
244
246
  ) -> Iterator[AddableDict]:
245
247
  streaming_parser = _StreamingParser(self.parser)
246
248
  for chunk in input:
247
- yield from streaming_parser.parse(chunk)
249
+ if isinstance(chunk, AIMessage):
250
+ yield from streaming_parser.parse(convert_from_v1_message(chunk))
251
+ else:
252
+ yield from streaming_parser.parse(chunk)
248
253
  streaming_parser.close()
249
254
 
250
255
  @override
251
256
  async def _atransform(
252
- self, input: AsyncIterator[Union[str, BaseMessage]]
257
+ self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
253
258
  ) -> AsyncIterator[AddableDict]:
254
259
  streaming_parser = _StreamingParser(self.parser)
255
260
  async for chunk in input:
256
- for output in streaming_parser.parse(chunk):
257
- yield output
261
+ if isinstance(chunk, AIMessage):
262
+ for output in streaming_parser.parse(convert_from_v1_message(chunk)):
263
+ yield output
264
+ else:
265
+ for output in streaming_parser.parse(chunk):
266
+ yield output
258
267
  streaming_parser.close()
259
268
 
260
269
  def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
@@ -46,7 +46,10 @@ class LLMResult(BaseModel):
46
46
  relevant information from standardized fields present in AIMessage.
47
47
  """
48
48
  run: Optional[list[RunInfo]] = None
49
- """List of metadata info for model call for each input."""
49
+ """List of metadata info for model call for each input.
50
+
51
+ See :class:`~langchain_core.outputs.run_info.RunInfo` for details.
52
+ """
50
53
 
51
54
  type: Literal["LLMResult"] = "LLMResult"
52
55
  """Type is used exclusively for serialization purposes."""