langchain-core 0.3.71__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_core/__init__.py +1 -1
- langchain_core/_api/beta_decorator.py +1 -0
- langchain_core/_api/deprecation.py +2 -0
- langchain_core/beta/runnables/context.py +1 -0
- langchain_core/callbacks/base.py +23 -14
- langchain_core/callbacks/file.py +1 -0
- langchain_core/callbacks/manager.py +145 -19
- langchain_core/callbacks/streaming_stdout.py +4 -3
- langchain_core/callbacks/usage.py +15 -3
- langchain_core/chat_history.py +1 -0
- langchain_core/document_loaders/langsmith.py +2 -1
- langchain_core/documents/base.py +2 -0
- langchain_core/embeddings/fake.py +2 -0
- langchain_core/indexing/api.py +10 -0
- langchain_core/language_models/_utils.py +37 -0
- langchain_core/language_models/base.py +4 -1
- langchain_core/language_models/chat_models.py +48 -27
- langchain_core/language_models/fake_chat_models.py +71 -1
- langchain_core/language_models/llms.py +1 -0
- langchain_core/memory.py +1 -0
- langchain_core/messages/__init__.py +54 -0
- langchain_core/messages/ai.py +31 -18
- langchain_core/messages/content_blocks.py +1349 -69
- langchain_core/messages/human.py +1 -0
- langchain_core/messages/modifier.py +1 -1
- langchain_core/messages/tool.py +8 -83
- langchain_core/messages/utils.py +221 -6
- langchain_core/output_parsers/base.py +51 -14
- langchain_core/output_parsers/json.py +5 -2
- langchain_core/output_parsers/list.py +7 -2
- langchain_core/output_parsers/openai_functions.py +29 -5
- langchain_core/output_parsers/openai_tools.py +90 -47
- langchain_core/output_parsers/pydantic.py +3 -2
- langchain_core/output_parsers/transform.py +53 -12
- langchain_core/output_parsers/xml.py +14 -5
- langchain_core/outputs/llm_result.py +4 -1
- langchain_core/prompt_values.py +111 -7
- langchain_core/prompts/base.py +4 -0
- langchain_core/prompts/chat.py +3 -0
- langchain_core/prompts/few_shot.py +1 -0
- langchain_core/prompts/few_shot_with_templates.py +1 -0
- langchain_core/prompts/image.py +1 -0
- langchain_core/prompts/pipeline.py +1 -0
- langchain_core/prompts/prompt.py +1 -0
- langchain_core/prompts/structured.py +1 -0
- langchain_core/rate_limiters.py +1 -0
- langchain_core/retrievers.py +3 -0
- langchain_core/runnables/base.py +75 -57
- langchain_core/runnables/branch.py +1 -0
- langchain_core/runnables/config.py +2 -2
- langchain_core/runnables/configurable.py +2 -1
- langchain_core/runnables/fallbacks.py +3 -7
- langchain_core/runnables/graph.py +5 -3
- langchain_core/runnables/graph_ascii.py +1 -0
- langchain_core/runnables/graph_mermaid.py +1 -0
- langchain_core/runnables/history.py +1 -0
- langchain_core/runnables/passthrough.py +3 -0
- langchain_core/runnables/retry.py +1 -0
- langchain_core/runnables/router.py +1 -0
- langchain_core/runnables/schema.py +1 -0
- langchain_core/stores.py +3 -0
- langchain_core/tools/base.py +43 -11
- langchain_core/tools/convert.py +25 -3
- langchain_core/tools/retriever.py +8 -1
- langchain_core/tools/structured.py +10 -1
- langchain_core/tracers/base.py +14 -7
- langchain_core/tracers/context.py +1 -1
- langchain_core/tracers/core.py +27 -4
- langchain_core/tracers/event_stream.py +14 -3
- langchain_core/tracers/langchain.py +14 -3
- langchain_core/tracers/log_stream.py +4 -1
- langchain_core/utils/aiter.py +5 -0
- langchain_core/utils/function_calling.py +2 -1
- langchain_core/utils/iter.py +1 -0
- langchain_core/utils/json_schema.py +1 -1
- langchain_core/v1/__init__.py +1 -0
- langchain_core/v1/chat_models.py +1047 -0
- langchain_core/v1/messages.py +755 -0
- langchain_core/vectorstores/base.py +1 -0
- langchain_core/version.py +1 -1
- {langchain_core-0.3.71.dist-info → langchain_core-0.4.0.dev0.dist-info}/METADATA +1 -1
- {langchain_core-0.3.71.dist-info → langchain_core-0.4.0.dev0.dist-info}/RECORD +84 -81
- {langchain_core-0.3.71.dist-info → langchain_core-0.4.0.dev0.dist-info}/WHEEL +0 -0
- {langchain_core-0.3.71.dist-info → langchain_core-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
|
@@ -13,6 +13,7 @@ from typing_extensions import override
|
|
|
13
13
|
|
|
14
14
|
from langchain_core.messages import BaseMessage
|
|
15
15
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
|
16
|
+
from langchain_core.v1.messages import AIMessage
|
|
16
17
|
|
|
17
18
|
if TYPE_CHECKING:
|
|
18
19
|
from collections.abc import AsyncIterator, Iterator
|
|
@@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|
|
71
72
|
|
|
72
73
|
@override
|
|
73
74
|
def _transform(
|
|
74
|
-
self, input: Iterator[Union[str, BaseMessage]]
|
|
75
|
+
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
|
75
76
|
) -> Iterator[list[str]]:
|
|
76
77
|
buffer = ""
|
|
77
78
|
for chunk in input:
|
|
@@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|
|
81
82
|
if not isinstance(chunk_content, str):
|
|
82
83
|
continue
|
|
83
84
|
buffer += chunk_content
|
|
85
|
+
elif isinstance(chunk, AIMessage):
|
|
86
|
+
buffer += chunk.text
|
|
84
87
|
else:
|
|
85
88
|
# add current chunk to buffer
|
|
86
89
|
buffer += chunk
|
|
@@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|
|
105
108
|
|
|
106
109
|
@override
|
|
107
110
|
async def _atransform(
|
|
108
|
-
self, input: AsyncIterator[Union[str, BaseMessage]]
|
|
111
|
+
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
|
109
112
|
) -> AsyncIterator[list[str]]:
|
|
110
113
|
buffer = ""
|
|
111
114
|
async for chunk in input:
|
|
@@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|
|
115
118
|
if not isinstance(chunk_content, str):
|
|
116
119
|
continue
|
|
117
120
|
buffer += chunk_content
|
|
121
|
+
elif isinstance(chunk, AIMessage):
|
|
122
|
+
buffer += chunk.text
|
|
118
123
|
else:
|
|
119
124
|
# add current chunk to buffer
|
|
120
125
|
buffer += chunk
|
|
@@ -17,6 +17,7 @@ from langchain_core.output_parsers import (
|
|
|
17
17
|
)
|
|
18
18
|
from langchain_core.output_parsers.json import parse_partial_json
|
|
19
19
|
from langchain_core.outputs import ChatGeneration, Generation
|
|
20
|
+
from langchain_core.v1.messages import AIMessage
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|
@@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|
|
26
27
|
"""Whether to only return the arguments to the function call."""
|
|
27
28
|
|
|
28
29
|
@override
|
|
29
|
-
def parse_result(
|
|
30
|
+
def parse_result(
|
|
31
|
+
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
|
32
|
+
) -> Any:
|
|
30
33
|
"""Parse the result of an LLM call to a JSON object.
|
|
31
34
|
|
|
32
35
|
Args:
|
|
@@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|
|
39
42
|
Raises:
|
|
40
43
|
OutputParserException: If the output is not valid JSON.
|
|
41
44
|
"""
|
|
45
|
+
if isinstance(result, AIMessage):
|
|
46
|
+
msg = (
|
|
47
|
+
"This output parser does not support v1 AIMessages. Use "
|
|
48
|
+
"JsonOutputToolsParser instead."
|
|
49
|
+
)
|
|
50
|
+
raise TypeError(msg)
|
|
42
51
|
generation = result[0]
|
|
43
52
|
if not isinstance(generation, ChatGeneration):
|
|
44
53
|
msg = "This output parser can only be used with a chat generation."
|
|
@@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
|
77
86
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
|
78
87
|
return jsonpatch.make_patch(prev, next).patch
|
|
79
88
|
|
|
80
|
-
def parse_result(
|
|
89
|
+
def parse_result(
|
|
90
|
+
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
|
91
|
+
) -> Any:
|
|
81
92
|
"""Parse the result of an LLM call to a JSON object.
|
|
82
93
|
|
|
83
94
|
Args:
|
|
@@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
|
90
101
|
Raises:
|
|
91
102
|
OutputParserException: If the output is not valid JSON.
|
|
92
103
|
"""
|
|
104
|
+
if isinstance(result, AIMessage):
|
|
105
|
+
msg = (
|
|
106
|
+
"This output parser does not support v1 AIMessages. Use "
|
|
107
|
+
"JsonOutputToolsParser instead."
|
|
108
|
+
)
|
|
109
|
+
raise TypeError(msg)
|
|
93
110
|
if len(result) != 1:
|
|
94
111
|
msg = f"Expected exactly one result, but got {len(result)}"
|
|
95
112
|
raise OutputParserException(msg)
|
|
@@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
|
|
160
177
|
key_name: str
|
|
161
178
|
"""The name of the key to return."""
|
|
162
179
|
|
|
163
|
-
def parse_result(
|
|
180
|
+
def parse_result(
|
|
181
|
+
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
|
182
|
+
) -> Any:
|
|
164
183
|
"""Parse the result of an LLM call to a JSON object.
|
|
165
184
|
|
|
166
185
|
Args:
|
|
@@ -214,6 +233,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|
|
214
233
|
pydantic_schema={"cookie": Cookie, "dog": Dog}
|
|
215
234
|
)
|
|
216
235
|
result = parser.parse_result([chat_generation])
|
|
236
|
+
|
|
217
237
|
"""
|
|
218
238
|
|
|
219
239
|
pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]]
|
|
@@ -253,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|
|
253
273
|
return values
|
|
254
274
|
|
|
255
275
|
@override
|
|
256
|
-
def parse_result(
|
|
276
|
+
def parse_result(
|
|
277
|
+
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
|
278
|
+
) -> Any:
|
|
257
279
|
"""Parse the result of an LLM call to a JSON object.
|
|
258
280
|
|
|
259
281
|
Args:
|
|
@@ -293,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
|
|
293
315
|
"""The name of the attribute to return."""
|
|
294
316
|
|
|
295
317
|
@override
|
|
296
|
-
def parse_result(
|
|
318
|
+
def parse_result(
|
|
319
|
+
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
|
320
|
+
) -> Any:
|
|
297
321
|
"""Parse the result of an LLM call to a JSON object.
|
|
298
322
|
|
|
299
323
|
Args:
|
|
@@ -4,7 +4,7 @@ import copy
|
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
6
|
from json import JSONDecodeError
|
|
7
|
-
from typing import Annotated, Any, Optional
|
|
7
|
+
from typing import Annotated, Any, Optional, Union
|
|
8
8
|
|
|
9
9
|
from pydantic import SkipValidation, ValidationError
|
|
10
10
|
|
|
@@ -16,6 +16,7 @@ from langchain_core.output_parsers.transform import BaseCumulativeTransformOutpu
|
|
|
16
16
|
from langchain_core.outputs import ChatGeneration, Generation
|
|
17
17
|
from langchain_core.utils.json import parse_partial_json
|
|
18
18
|
from langchain_core.utils.pydantic import TypeBaseModel
|
|
19
|
+
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
|
19
20
|
|
|
20
21
|
logger = logging.getLogger(__name__)
|
|
21
22
|
|
|
@@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
|
156
157
|
If no tool calls are found, None will be returned.
|
|
157
158
|
"""
|
|
158
159
|
|
|
159
|
-
def parse_result(
|
|
160
|
+
def parse_result(
|
|
161
|
+
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
|
162
|
+
) -> Any:
|
|
160
163
|
"""Parse the result of an LLM call to a list of tool calls.
|
|
161
164
|
|
|
162
165
|
Args:
|
|
@@ -173,31 +176,45 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
|
173
176
|
Raises:
|
|
174
177
|
OutputParserException: If the output is not valid JSON.
|
|
175
178
|
"""
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
179
|
+
if isinstance(result, list):
|
|
180
|
+
generation = result[0]
|
|
181
|
+
if not isinstance(generation, ChatGeneration):
|
|
182
|
+
msg = (
|
|
183
|
+
"This output parser can only be used with a chat generation or "
|
|
184
|
+
"v1 AIMessage."
|
|
185
|
+
)
|
|
186
|
+
raise OutputParserException(msg)
|
|
187
|
+
message = generation.message
|
|
188
|
+
if isinstance(message, AIMessage) and message.tool_calls:
|
|
189
|
+
tool_calls = [dict(tc) for tc in message.tool_calls]
|
|
190
|
+
for tool_call in tool_calls:
|
|
191
|
+
if not self.return_id:
|
|
192
|
+
_ = tool_call.pop("id")
|
|
193
|
+
else:
|
|
194
|
+
try:
|
|
195
|
+
raw_tool_calls = copy.deepcopy(
|
|
196
|
+
message.additional_kwargs["tool_calls"]
|
|
197
|
+
)
|
|
198
|
+
except KeyError:
|
|
199
|
+
return []
|
|
200
|
+
tool_calls = parse_tool_calls(
|
|
201
|
+
raw_tool_calls,
|
|
202
|
+
partial=partial,
|
|
203
|
+
strict=self.strict,
|
|
204
|
+
return_id=self.return_id,
|
|
205
|
+
)
|
|
206
|
+
elif result.tool_calls:
|
|
207
|
+
# v1 message
|
|
208
|
+
tool_calls = [dict(tc) for tc in result.tool_calls]
|
|
183
209
|
for tool_call in tool_calls:
|
|
184
210
|
if not self.return_id:
|
|
185
211
|
_ = tool_call.pop("id")
|
|
186
212
|
else:
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
except KeyError:
|
|
190
|
-
return []
|
|
191
|
-
tool_calls = parse_tool_calls(
|
|
192
|
-
raw_tool_calls,
|
|
193
|
-
partial=partial,
|
|
194
|
-
strict=self.strict,
|
|
195
|
-
return_id=self.return_id,
|
|
196
|
-
)
|
|
213
|
+
return []
|
|
214
|
+
|
|
197
215
|
# for backwards compatibility
|
|
198
216
|
for tc in tool_calls:
|
|
199
217
|
tc["type"] = tc.pop("name")
|
|
200
|
-
|
|
201
218
|
if self.first_tool_only:
|
|
202
219
|
return tool_calls[0] if tool_calls else None
|
|
203
220
|
return tool_calls
|
|
@@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|
|
220
237
|
key_name: str
|
|
221
238
|
"""The type of tools to return."""
|
|
222
239
|
|
|
223
|
-
def parse_result(
|
|
240
|
+
def parse_result(
|
|
241
|
+
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
|
242
|
+
) -> Any:
|
|
224
243
|
"""Parse the result of an LLM call to a list of tool calls.
|
|
225
244
|
|
|
226
245
|
Args:
|
|
@@ -234,32 +253,47 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|
|
234
253
|
Returns:
|
|
235
254
|
The parsed tool calls.
|
|
236
255
|
"""
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
256
|
+
if isinstance(result, list):
|
|
257
|
+
generation = result[0]
|
|
258
|
+
if not isinstance(generation, ChatGeneration):
|
|
259
|
+
msg = "This output parser can only be used with a chat generation."
|
|
260
|
+
raise OutputParserException(msg)
|
|
261
|
+
message = generation.message
|
|
262
|
+
if isinstance(message, AIMessage) and message.tool_calls:
|
|
263
|
+
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
|
|
264
|
+
for tool_call in parsed_tool_calls:
|
|
265
|
+
if not self.return_id:
|
|
266
|
+
_ = tool_call.pop("id")
|
|
267
|
+
else:
|
|
268
|
+
try:
|
|
269
|
+
raw_tool_calls = copy.deepcopy(
|
|
270
|
+
message.additional_kwargs["tool_calls"]
|
|
271
|
+
)
|
|
272
|
+
except KeyError:
|
|
273
|
+
if self.first_tool_only:
|
|
274
|
+
return None
|
|
275
|
+
return []
|
|
276
|
+
parsed_tool_calls = parse_tool_calls(
|
|
277
|
+
raw_tool_calls,
|
|
278
|
+
partial=partial,
|
|
279
|
+
strict=self.strict,
|
|
280
|
+
return_id=self.return_id,
|
|
281
|
+
)
|
|
282
|
+
elif result.tool_calls:
|
|
283
|
+
# v1 message
|
|
284
|
+
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
|
|
244
285
|
for tool_call in parsed_tool_calls:
|
|
245
286
|
if not self.return_id:
|
|
246
287
|
_ = tool_call.pop("id")
|
|
247
288
|
else:
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
return None
|
|
253
|
-
return []
|
|
254
|
-
parsed_tool_calls = parse_tool_calls(
|
|
255
|
-
raw_tool_calls,
|
|
256
|
-
partial=partial,
|
|
257
|
-
strict=self.strict,
|
|
258
|
-
return_id=self.return_id,
|
|
259
|
-
)
|
|
289
|
+
if self.first_tool_only:
|
|
290
|
+
return None
|
|
291
|
+
return []
|
|
292
|
+
|
|
260
293
|
# For backwards compatibility
|
|
261
294
|
for tc in parsed_tool_calls:
|
|
262
295
|
tc["type"] = tc.pop("name")
|
|
296
|
+
|
|
263
297
|
if self.first_tool_only:
|
|
264
298
|
parsed_result = list(
|
|
265
299
|
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
|
|
@@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
|
|
299
333
|
|
|
300
334
|
# TODO: Support more granular streaming of objects. Currently only streams once all
|
|
301
335
|
# Pydantic object fields are present.
|
|
302
|
-
def parse_result(
|
|
336
|
+
def parse_result(
|
|
337
|
+
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
|
338
|
+
) -> Any:
|
|
303
339
|
"""Parse the result of an LLM call to a list of Pydantic objects.
|
|
304
340
|
|
|
305
341
|
Args:
|
|
@@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
|
|
337
373
|
except (ValidationError, ValueError):
|
|
338
374
|
if partial:
|
|
339
375
|
continue
|
|
340
|
-
has_max_tokens_stop_reason =
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
376
|
+
has_max_tokens_stop_reason = False
|
|
377
|
+
if isinstance(result, list):
|
|
378
|
+
has_max_tokens_stop_reason = any(
|
|
379
|
+
generation.message.response_metadata.get("stop_reason")
|
|
380
|
+
== "max_tokens"
|
|
381
|
+
for generation in result
|
|
382
|
+
if isinstance(generation, ChatGeneration)
|
|
383
|
+
)
|
|
384
|
+
else:
|
|
385
|
+
# v1 message
|
|
386
|
+
has_max_tokens_stop_reason = (
|
|
387
|
+
result.response_metadata.get("stop_reason") == "max_tokens"
|
|
388
|
+
)
|
|
346
389
|
if has_max_tokens_stop_reason:
|
|
347
390
|
logger.exception(_MAX_TOKENS_ERROR)
|
|
348
391
|
raise
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Output parsers using Pydantic."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from typing import Annotated, Generic, Optional
|
|
4
|
+
from typing import Annotated, Generic, Optional, Union
|
|
5
5
|
|
|
6
6
|
import pydantic
|
|
7
7
|
from pydantic import SkipValidation
|
|
@@ -14,6 +14,7 @@ from langchain_core.utils.pydantic import (
|
|
|
14
14
|
PydanticBaseModel,
|
|
15
15
|
TBaseModel,
|
|
16
16
|
)
|
|
17
|
+
from langchain_core.v1.messages import AIMessage
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|
@@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|
|
43
44
|
return OutputParserException(msg, llm_output=json_string)
|
|
44
45
|
|
|
45
46
|
def parse_result(
|
|
46
|
-
self, result: list[Generation], *, partial: bool = False
|
|
47
|
+
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
|
47
48
|
) -> Optional[TBaseModel]:
|
|
48
49
|
"""Parse the result of an LLM call to a pydantic object.
|
|
49
50
|
|
|
@@ -20,6 +20,7 @@ from langchain_core.outputs import (
|
|
|
20
20
|
GenerationChunk,
|
|
21
21
|
)
|
|
22
22
|
from langchain_core.runnables.config import run_in_executor
|
|
23
|
+
from langchain_core.v1.messages import AIMessage, AIMessageChunk
|
|
23
24
|
|
|
24
25
|
if TYPE_CHECKING:
|
|
25
26
|
from collections.abc import AsyncIterator, Iterator
|
|
@@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|
|
32
33
|
|
|
33
34
|
def _transform(
|
|
34
35
|
self,
|
|
35
|
-
input: Iterator[Union[str, BaseMessage]],
|
|
36
|
+
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
|
36
37
|
) -> Iterator[T]:
|
|
37
38
|
for chunk in input:
|
|
38
39
|
if isinstance(chunk, BaseMessage):
|
|
39
40
|
yield self.parse_result([ChatGeneration(message=chunk)])
|
|
41
|
+
elif isinstance(chunk, AIMessage):
|
|
42
|
+
yield self.parse_result(chunk)
|
|
40
43
|
else:
|
|
41
44
|
yield self.parse_result([Generation(text=chunk)])
|
|
42
45
|
|
|
43
46
|
async def _atransform(
|
|
44
47
|
self,
|
|
45
|
-
input: AsyncIterator[Union[str, BaseMessage]],
|
|
48
|
+
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
|
46
49
|
) -> AsyncIterator[T]:
|
|
47
50
|
async for chunk in input:
|
|
48
51
|
if isinstance(chunk, BaseMessage):
|
|
49
52
|
yield await run_in_executor(
|
|
50
53
|
None, self.parse_result, [ChatGeneration(message=chunk)]
|
|
51
54
|
)
|
|
55
|
+
elif isinstance(chunk, AIMessage):
|
|
56
|
+
yield await run_in_executor(None, self.parse_result, chunk)
|
|
52
57
|
else:
|
|
53
58
|
yield await run_in_executor(
|
|
54
59
|
None, self.parse_result, [Generation(text=chunk)]
|
|
@@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|
|
57
62
|
@override
|
|
58
63
|
def transform(
|
|
59
64
|
self,
|
|
60
|
-
input: Iterator[Union[str, BaseMessage]],
|
|
65
|
+
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
|
61
66
|
config: Optional[RunnableConfig] = None,
|
|
62
67
|
**kwargs: Any,
|
|
63
68
|
) -> Iterator[T]:
|
|
@@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|
|
78
83
|
@override
|
|
79
84
|
async def atransform(
|
|
80
85
|
self,
|
|
81
|
-
input: AsyncIterator[Union[str, BaseMessage]],
|
|
86
|
+
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
|
82
87
|
config: Optional[RunnableConfig] = None,
|
|
83
88
|
**kwargs: Any,
|
|
84
89
|
) -> AsyncIterator[T]:
|
|
@@ -125,23 +130,42 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|
|
125
130
|
raise NotImplementedError
|
|
126
131
|
|
|
127
132
|
@override
|
|
128
|
-
def _transform(
|
|
133
|
+
def _transform(
|
|
134
|
+
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
|
135
|
+
) -> Iterator[Any]:
|
|
129
136
|
prev_parsed = None
|
|
130
|
-
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] =
|
|
137
|
+
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
|
138
|
+
None
|
|
139
|
+
)
|
|
131
140
|
for chunk in input:
|
|
132
|
-
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
|
141
|
+
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
|
133
142
|
if isinstance(chunk, BaseMessageChunk):
|
|
134
143
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
|
135
144
|
elif isinstance(chunk, BaseMessage):
|
|
136
145
|
chunk_gen = ChatGenerationChunk(
|
|
137
146
|
message=BaseMessageChunk(**chunk.model_dump())
|
|
138
147
|
)
|
|
148
|
+
elif isinstance(chunk, AIMessageChunk):
|
|
149
|
+
chunk_gen = chunk
|
|
150
|
+
elif isinstance(chunk, AIMessage):
|
|
151
|
+
chunk_gen = AIMessageChunk(
|
|
152
|
+
content=chunk.content,
|
|
153
|
+
id=chunk.id,
|
|
154
|
+
name=chunk.name,
|
|
155
|
+
lc_version=chunk.lc_version,
|
|
156
|
+
response_metadata=chunk.response_metadata,
|
|
157
|
+
usage_metadata=chunk.usage_metadata,
|
|
158
|
+
parsed=chunk.parsed,
|
|
159
|
+
)
|
|
139
160
|
else:
|
|
140
161
|
chunk_gen = GenerationChunk(text=chunk)
|
|
141
162
|
|
|
142
163
|
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
|
143
164
|
|
|
144
|
-
|
|
165
|
+
if isinstance(acc_gen, AIMessageChunk):
|
|
166
|
+
parsed = self.parse_result(acc_gen, partial=True)
|
|
167
|
+
else:
|
|
168
|
+
parsed = self.parse_result([acc_gen], partial=True)
|
|
145
169
|
if parsed is not None and parsed != prev_parsed:
|
|
146
170
|
if self.diff:
|
|
147
171
|
yield self._diff(prev_parsed, parsed)
|
|
@@ -151,24 +175,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|
|
151
175
|
|
|
152
176
|
@override
|
|
153
177
|
async def _atransform(
|
|
154
|
-
self, input: AsyncIterator[Union[str, BaseMessage]]
|
|
178
|
+
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
|
155
179
|
) -> AsyncIterator[T]:
|
|
156
180
|
prev_parsed = None
|
|
157
|
-
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] =
|
|
181
|
+
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
|
182
|
+
None
|
|
183
|
+
)
|
|
158
184
|
async for chunk in input:
|
|
159
|
-
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
|
185
|
+
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
|
160
186
|
if isinstance(chunk, BaseMessageChunk):
|
|
161
187
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
|
162
188
|
elif isinstance(chunk, BaseMessage):
|
|
163
189
|
chunk_gen = ChatGenerationChunk(
|
|
164
190
|
message=BaseMessageChunk(**chunk.model_dump())
|
|
165
191
|
)
|
|
192
|
+
elif isinstance(chunk, AIMessageChunk):
|
|
193
|
+
chunk_gen = chunk
|
|
194
|
+
elif isinstance(chunk, AIMessage):
|
|
195
|
+
chunk_gen = AIMessageChunk(
|
|
196
|
+
content=chunk.content,
|
|
197
|
+
id=chunk.id,
|
|
198
|
+
name=chunk.name,
|
|
199
|
+
lc_version=chunk.lc_version,
|
|
200
|
+
response_metadata=chunk.response_metadata,
|
|
201
|
+
usage_metadata=chunk.usage_metadata,
|
|
202
|
+
parsed=chunk.parsed,
|
|
203
|
+
)
|
|
166
204
|
else:
|
|
167
205
|
chunk_gen = GenerationChunk(text=chunk)
|
|
168
206
|
|
|
169
207
|
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
|
170
208
|
|
|
171
|
-
|
|
209
|
+
if isinstance(acc_gen, AIMessageChunk):
|
|
210
|
+
parsed = await self.aparse_result(acc_gen, partial=True)
|
|
211
|
+
else:
|
|
212
|
+
parsed = await self.aparse_result([acc_gen], partial=True)
|
|
172
213
|
if parsed is not None and parsed != prev_parsed:
|
|
173
214
|
if self.diff:
|
|
174
215
|
yield await run_in_executor(None, self._diff, prev_parsed, parsed)
|
|
@@ -12,8 +12,10 @@ from typing_extensions import override
|
|
|
12
12
|
|
|
13
13
|
from langchain_core.exceptions import OutputParserException
|
|
14
14
|
from langchain_core.messages import BaseMessage
|
|
15
|
+
from langchain_core.messages.utils import convert_from_v1_message
|
|
15
16
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
|
16
17
|
from langchain_core.runnables.utils import AddableDict
|
|
18
|
+
from langchain_core.v1.messages import AIMessage
|
|
17
19
|
|
|
18
20
|
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
|
19
21
|
1. Output should conform to the tags below.
|
|
@@ -240,21 +242,28 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|
|
240
242
|
|
|
241
243
|
@override
|
|
242
244
|
def _transform(
|
|
243
|
-
self, input: Iterator[Union[str, BaseMessage]]
|
|
245
|
+
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
|
244
246
|
) -> Iterator[AddableDict]:
|
|
245
247
|
streaming_parser = _StreamingParser(self.parser)
|
|
246
248
|
for chunk in input:
|
|
247
|
-
|
|
249
|
+
if isinstance(chunk, AIMessage):
|
|
250
|
+
yield from streaming_parser.parse(convert_from_v1_message(chunk))
|
|
251
|
+
else:
|
|
252
|
+
yield from streaming_parser.parse(chunk)
|
|
248
253
|
streaming_parser.close()
|
|
249
254
|
|
|
250
255
|
@override
|
|
251
256
|
async def _atransform(
|
|
252
|
-
self, input: AsyncIterator[Union[str, BaseMessage]]
|
|
257
|
+
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
|
253
258
|
) -> AsyncIterator[AddableDict]:
|
|
254
259
|
streaming_parser = _StreamingParser(self.parser)
|
|
255
260
|
async for chunk in input:
|
|
256
|
-
|
|
257
|
-
|
|
261
|
+
if isinstance(chunk, AIMessage):
|
|
262
|
+
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
|
|
263
|
+
yield output
|
|
264
|
+
else:
|
|
265
|
+
for output in streaming_parser.parse(chunk):
|
|
266
|
+
yield output
|
|
258
267
|
streaming_parser.close()
|
|
259
268
|
|
|
260
269
|
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
|
|
@@ -46,7 +46,10 @@ class LLMResult(BaseModel):
|
|
|
46
46
|
relevant information from standardized fields present in AIMessage.
|
|
47
47
|
"""
|
|
48
48
|
run: Optional[list[RunInfo]] = None
|
|
49
|
-
"""List of metadata info for model call for each input.
|
|
49
|
+
"""List of metadata info for model call for each input.
|
|
50
|
+
|
|
51
|
+
See :class:`~langchain_core.outputs.run_info.RunInfo` for details.
|
|
52
|
+
"""
|
|
50
53
|
|
|
51
54
|
type: Literal["LLMResult"] = "LLMResult"
|
|
52
55
|
"""Type is used exclusively for serialization purposes."""
|