langchain 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/_internal/_lazy_import.py +2 -3
- langchain/_internal/_prompts.py +11 -18
- langchain/_internal/_typing.py +3 -3
- langchain/agents/_internal/_typing.py +2 -2
- langchain/agents/interrupt.py +14 -9
- langchain/agents/middleware/__init__.py +15 -0
- langchain/agents/middleware/_utils.py +11 -0
- langchain/agents/middleware/human_in_the_loop.py +135 -0
- langchain/agents/middleware/prompt_caching.py +62 -0
- langchain/agents/middleware/summarization.py +248 -0
- langchain/agents/middleware/types.py +79 -0
- langchain/agents/middleware_agent.py +557 -0
- langchain/agents/react_agent.py +114 -61
- langchain/agents/structured_output.py +29 -24
- langchain/agents/tool_node.py +71 -65
- langchain/chat_models/__init__.py +2 -0
- langchain/chat_models/base.py +30 -32
- langchain/documents/__init__.py +2 -0
- langchain/embeddings/__init__.py +2 -0
- langchain/embeddings/base.py +6 -10
- langchain/embeddings/cache.py +5 -8
- langchain/storage/encoder_backed.py +9 -4
- langchain/storage/exceptions.py +2 -0
- langchain/tools/__init__.py +2 -0
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/METADATA +14 -18
- langchain-1.0.0a5.dist-info/RECORD +40 -0
- langchain-1.0.0a3.dist-info/RECORD +0 -33
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/entry_points.txt +0 -0
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/licenses/LICENSE +0 -0
langchain/agents/tool_node.py
CHANGED
@@ -23,10 +23,12 @@ Typical Usage:
     from langchain_core.tools import tool
     from langchain.agents import ToolNode
 
+
     @tool
     def my_tool(x: int) -> str:
         return f"Result: {x}"
 
+
     tool_node = ToolNode([my_tool])
     ```
 """
@@ -38,6 +40,7 @@ import inspect
 import json
 from copy import copy, deepcopy
 from dataclasses import replace
+from types import UnionType
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -85,11 +88,19 @@ INVALID_TOOL_NAME_ERROR_TEMPLATE = (
     "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
 )
 TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
-TOOL_EXECUTION_ERROR_TEMPLATE = …
-…
+TOOL_EXECUTION_ERROR_TEMPLATE = (
+    "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
+    " {error}\n"
+    " Please fix the error and try again."
+)
+TOOL_INVOCATION_ERROR_TEMPLATE = (
+    "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
+    " {error}\n"
+    " Please fix the error and try again."
+)
 
 
-def msg_content_output(output: Any) -> …
+def msg_content_output(output: Any) -> str | list[dict]:
     """Convert tool output to valid message content format.
 
     LangChain ToolMessages accept either string content or a list of content blocks.
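Note: the two new templates are plain `str.format` strings with `tool_name`, `tool_kwargs`, and `error` placeholders. A minimal sketch of how one renders (the tool name, kwargs, and error below are invented):

```python
from langchain.agents.tool_node import TOOL_EXECUTION_ERROR_TEMPLATE

# Hypothetical values, purely for illustration.
message = TOOL_EXECUTION_ERROR_TEMPLATE.format(
    tool_name="my_tool",
    tool_kwargs={"x": 1},
    error=ValueError("x must be positive"),
)
print(message)
# Error executing tool 'my_tool' with kwargs {'x': 1} with error:
#  x must be positive
#  Please fix the error and try again.
```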
@@ -159,13 +170,7 @@ def _default_handle_tool_errors(e: Exception) -> str:
 def _handle_tool_error(
     e: Exception,
     *,
-    flag: Union[
-        bool,
-        str,
-        Callable[..., str],
-        type[Exception],
-        tuple[type[Exception], ...],
-    ],
+    flag: bool | str | Callable[..., str] | type[Exception] | tuple[type[Exception], ...],
 ) -> str:
     """Generate error message content based on exception handling configuration.
 
@@ -242,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception],
     type_hints = get_type_hints(handler)
     if first_param.name in type_hints:
         origin = get_origin(first_param.annotation)
-        if origin …
+        if origin in [Union, UnionType]:
             args = get_args(first_param.annotation)
             if all(issubclass(arg, Exception) for arg in args):
                 return tuple(args)
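Note: the new `types.UnionType` import is what lets `_infer_handled_types` recognize PEP 604 annotations: `X | Y` reports a different `get_origin` result than `typing.Union[X, Y]`, so the check now accepts both spellings. A stdlib-only illustration:

```python
from types import UnionType
from typing import Union, get_origin

# `ValueError | TypeError` produces types.UnionType,
# while Union[ValueError, TypeError] produces typing.Union.
assert get_origin(ValueError | TypeError) is UnionType
assert get_origin(Union[ValueError, TypeError]) is Union
```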
@@ -317,7 +322,8 @@ class ToolNode(RunnableCallable):
             error template containing the exception details.
         - **str**: Catch all errors and return a ToolMessage with this custom
             error message string.
-        - **type[Exception]**: Only catch exceptions with the specified type and …
+        - **type[Exception]**: Only catch exceptions with the specified type and
+            return the default error message for it.
         - **tuple[type[Exception], ...]**: Only catch exceptions with the specified
             types and return default error messages for them.
         - **Callable[..., str]**: Catch exceptions matching the callable's signature
@@ -369,21 +375,24 @@ class ToolNode(RunnableCallable):
             def handle_errors(e: ValueError) -> str:
                 return "Invalid input provided"
 
+
             tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
             ```
-    """
+    """  # noqa: E501
 
     name: str = "tools"
 
     def __init__(
         self,
-        tools: Sequence[…
+        tools: Sequence[BaseTool | Callable],
         *,
         name: str = "tools",
         tags: list[str] | None = None,
-        handle_tool_errors: …
-        …
-        …
+        handle_tool_errors: bool
+        | str
+        | Callable[..., str]
+        | type[Exception]
+        | tuple[type[Exception], ...] = _default_handle_tool_errors,
         messages_key: str = "messages",
     ) -> None:
         """Initialize the ToolNode with the provided tools and configuration.
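Note: a small sketch of two of the `handle_tool_errors` options documented above, using the flattened signature (the `divide` tool is invented):

```python
from langchain_core.tools import tool

from langchain.agents import ToolNode


@tool
def divide(a: float, b: float) -> float:
    """Divide a by b."""
    return a / b


# tuple[type[Exception], ...]: only ZeroDivisionError becomes an error
# ToolMessage; any other exception propagates.
node = ToolNode([divide], handle_tool_errors=(ZeroDivisionError,))

# str: catch all errors and reply with this fixed message instead.
forgiving_node = ToolNode([divide], handle_tool_errors="Tool failed, try different arguments.")
```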
@@ -417,11 +426,7 @@ class ToolNode(RunnableCallable):
 
     def _func(
         self,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         config: RunnableConfig,
         *,
         store: Optional[BaseStore],  # noqa: UP045
@@ -436,11 +441,7 @@ class ToolNode(RunnableCallable):
 
     async def _afunc(
         self,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         config: RunnableConfig,
         *,
         store: Optional[BaseStore],  # noqa: UP045
@@ -454,9 +455,9 @@ class ToolNode(RunnableCallable):
 
     def _combine_tool_outputs(
         self,
-        outputs: list[…
+        outputs: list[ToolMessage | Command],
         input_type: Literal["list", "dict", "tool_calls"],
-    ) -> list[…
+    ) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]:
         # preserve existing behavior for non-command tool outputs for backwards
         # compatibility
         if not any(isinstance(output, Command) for output in outputs):
@@ -499,7 +500,7 @@ class ToolNode(RunnableCallable):
         call: ToolCall,
         input_type: Literal["list", "dict", "tool_calls"],
         config: RunnableConfig,
-    ) -> …
+    ) -> ToolMessage | Command:
         """Run a single tool call synchronously."""
         if invalid_tool_message := self._validate_tool_call(call):
             return invalid_tool_message
@@ -515,10 +516,12 @@ class ToolNode(RunnableCallable):
 
         # GraphInterrupt is a special exception that will always be raised.
         # It can be triggered in the following scenarios,
-        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation …
+        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
+        # most commonly:
         # (1) a GraphInterrupt is raised inside a tool
         # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
-        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph …
+        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
+        # called as a tool
         # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
         except GraphBubbleUp:
             raise
@@ -553,7 +556,7 @@ class ToolNode(RunnableCallable):
         if isinstance(response, Command):
             return self._validate_tool_command(response, call, input_type)
         if isinstance(response, ToolMessage):
-            response.content = cast("…
+            response.content = cast("str | list", msg_content_output(response.content))
             return response
         msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
         raise TypeError(msg)
@@ -563,7 +566,7 @@ class ToolNode(RunnableCallable):
         call: ToolCall,
         input_type: Literal["list", "dict", "tool_calls"],
         config: RunnableConfig,
-    ) -> …
+    ) -> ToolMessage | Command:
         """Run a single tool call asynchronously."""
         if invalid_tool_message := self._validate_tool_call(call):
             return invalid_tool_message
@@ -579,10 +582,12 @@ class ToolNode(RunnableCallable):
 
         # GraphInterrupt is a special exception that will always be raised.
         # It can be triggered in the following scenarios,
-        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation …
+        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
+        # most commonly:
         # (1) a GraphInterrupt is raised inside a tool
         # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
-        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph …
+        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
+        # called as a tool
         # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
         except GraphBubbleUp:
             raise
@@ -618,18 +623,14 @@ class ToolNode(RunnableCallable):
         if isinstance(response, Command):
             return self._validate_tool_command(response, call, input_type)
         if isinstance(response, ToolMessage):
-            response.content = cast("…
+            response.content = cast("str | list", msg_content_output(response.content))
             return response
         msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
         raise TypeError(msg)
 
     def _parse_input(
         self,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         store: BaseStore | None,
     ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
         input_type: Literal["list", "dict", "tool_calls"]
@@ -676,11 +677,7 @@ class ToolNode(RunnableCallable):
     def _inject_state(
         self,
         tool_call: ToolCall,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
     ) -> ToolCall:
         state_args = self._tool_to_state_args[tool_call["name"]]
         if state_args and isinstance(input, list):
@@ -737,11 +734,7 @@ class ToolNode(RunnableCallable):
     def inject_tool_args(
         self,
         tool_call: ToolCall,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         store: BaseStore | None,
     ) -> ToolCall:
         """Inject graph state and store into tool call arguments.
@@ -791,10 +784,12 @@ class ToolNode(RunnableCallable):
         input_type: Literal["list", "dict", "tool_calls"],
     ) -> Command:
         if isinstance(command.update, dict):
-            # input type is dict when ToolNode is invoked with a dict input …
+            # input type is dict when ToolNode is invoked with a dict input
+            # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
             if input_type not in ("dict", "tool_calls"):
                 msg = (
-                    …
+                    "Tools can provide a dict in Command.update only when using dict "
+                    f"with '{self._messages_key}' key as ToolNode input, "
                     f"got: {command.update} for tool '{call['name']}'"
                 )
                 raise ValueError(msg)
@@ -803,10 +798,12 @@ class ToolNode(RunnableCallable):
             state_update = cast("dict[str, Any]", updated_command.update) or {}
             messages_update = state_update.get(self._messages_key, [])
         elif isinstance(command.update, list):
-            # Input type is list when ToolNode is invoked with a list input …
+            # Input type is list when ToolNode is invoked with a list input
+            # (e.g. [AIMessage(..., tool_calls=[...])])
             if input_type != "list":
                 msg = (
-                    …
+                    "Tools can provide a list of messages in Command.update "
+                    "only when using list of messages as ToolNode input, "
                     f"got: {command.update} for tool '{call['name']}'"
                 )
                 raise ValueError(msg)
@@ -836,13 +833,17 @@ class ToolNode(RunnableCallable):
         # Command.update if command is sent to the CURRENT graph
         if updated_command.graph is None and not has_matching_tool_message:
             example_update = (
-                '`Command(update={"messages": …
+                '`Command(update={"messages": '
+                '[ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
                 if input_type == "dict"
-                else …
+                else "`Command(update="
+                '[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
             )
             msg = (
-                …
-                "…
+                "Expected to have a matching ToolMessage in Command.update "
+                f"for tool '{call['name']}', got: {messages_update}. "
+                "Every tool call (LLM requesting to call a tool) "
+                "in the message history MUST have a corresponding ToolMessage. "
                 f"You can fix it by modifying the tool to return {example_update}."
            )
            raise ValueError(msg)
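Note: a hedged sketch of a tool that satisfies this validation, assuming the usual LangGraph `Command` pattern (the `record_result` tool and the `values` state key are invented):

```python
from typing import Annotated

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.types import Command


@tool
def record_result(
    value: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Store a value in graph state."""
    # Command.update carries a ToolMessage whose tool_call_id matches the
    # originating tool call, which is exactly what the check above requires.
    return Command(
        update={
            "values": [value],
            "messages": [ToolMessage("Success", tool_call_id=tool_call_id)],
        }
    )
```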
@@ -850,7 +851,7 @@
 
 
 def tools_condition(
-    state: …
+    state: list[AnyMessage] | dict[str, Any] | BaseModel,
     messages_key: str = "messages",
 ) -> Literal["tools", "__end__"]:
     """Conditional routing function for tool-calling workflows.
@@ -887,16 +888,18 @@ def tools_condition(
         from langgraph.agents.tool_node import ToolNode, tools_condition
         from typing_extensions import TypedDict
 
+
         class State(TypedDict):
             messages: list
 
+
         graph = StateGraph(State)
         graph.add_node("llm", call_model)
         graph.add_node("tools", ToolNode([my_tool]))
         graph.add_conditional_edges(
             "llm",
             tools_condition,  # Routes to "tools" or "__end__"
-            {"tools": "tools", "__end__": "__end__"}
+            {"tools": "tools", "__end__": "__end__"},
         )
         ```
@@ -956,6 +959,7 @@ class InjectedState(InjectedToolArg):
             messages: List[BaseMessage]
             foo: str
 
+
         @tool
         def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
             '''Do something with state.'''
@@ -964,11 +968,13 @@ class InjectedState(InjectedToolArg):
             else:
                 return "not enough messages"
 
+
         @tool
         def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
             '''Do something else with state.'''
             return foo + str(x + 1)
 
+
         node = ToolNode([state_tool, foo_tool])
 
         tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
@@ -982,8 +988,8 @@ class InjectedState(InjectedToolArg):
 
         ```pycon
         [
-            ToolMessage(content=…
-            ToolMessage(content=…
+            ToolMessage(content="not enough messages", name="state_tool", tool_call_id="1"),
+            ToolMessage(content="bar2", name="foo_tool", tool_call_id="2"),
         ]
         ```
@@ -1078,7 +1084,7 @@ class InjectedStore(InjectedToolArg):
     """
 
 
-def _is_injection(type_arg: Any, injection_type: type[…
+def _is_injection(type_arg: Any, injection_type: type[InjectedState | InjectedStore]) -> bool:
     """Check if a type argument represents an injection annotation.
 
     This utility function determines whether a type annotation indicates that
langchain/chat_models/base.py
CHANGED
@@ -1,3 +1,5 @@
+"""Factory functions for chat models."""
+
 from __future__ import annotations
 
 import warnings
@@ -7,7 +9,6 @@ from typing import (
     Any,
     Literal,
     TypeAlias,
-    Union,
     cast,
     overload,
 )
@@ -53,7 +54,7 @@ def init_chat_model(
     model: str | None = None,
     *,
     model_provider: str | None = None,
-    configurable_fields: …
+    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
     config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
@@ -66,10 +67,10 @@ def init_chat_model(
     model: str | None = None,
     *,
     model_provider: str | None = None,
-    configurable_fields: …
+    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
     config_prefix: str | None = None,
     **kwargs: Any,
-) -> …
+) -> BaseChatModel | _ConfigurableModel:
     """Initialize a ChatModel from the model name and provider.
 
     **Note:** Must have the integration package corresponding to the model provider
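Note: the widened return annotation reflects the two behaviors already documented in this docstring: a fixed model name yields a `BaseChatModel`, while omitting it yields the runtime-configurable wrapper. A sketch, assuming the relevant provider packages are installed:

```python
from langchain.chat_models import init_chat_model

# Concrete model name -> BaseChatModel
model = init_chat_model("openai:gpt-4o", temperature=0)

# No model name -> _ConfigurableModel; the model is chosen per call
configurable = init_chat_model(temperature=0)
configurable.invoke("hi", config={"configurable": {"model": "gpt-4o"}})
```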
@@ -189,14 +190,12 @@ def init_chat_model(
     configurable_model = init_chat_model(temperature=0)
 
     configurable_model.invoke(
-        "what's your name",
-        config={"configurable": {"model": "gpt-4o"}}
+        "what's your name", config={"configurable": {"model": "gpt-4o"}}
     )
     # GPT-4o response
 
     configurable_model.invoke(
-        "what's your name",
-        config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
+        "what's your name", config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
     )
     # claude-3.5 sonnet response
@@ -211,7 +210,7 @@ def init_chat_model(
         "openai:gpt-4o",
         configurable_fields="any",  # this allows us to configure other params like temperature, max_tokens, etc at runtime.
         config_prefix="foo",
-        temperature=0
+        temperature=0,
     )
 
     configurable_model_with_default.invoke("what's your name")
@@ -222,9 +221,9 @@ def init_chat_model(
         config={
             "configurable": {
                 "foo_model": "anthropic:claude-3-5-sonnet-latest",
-                "foo_temperature": 0.6
+                "foo_temperature": 0.6,
             }
-        }
+        },
     )
     # Claude-3.5 sonnet response with temperature 0.6
@@ -239,23 +238,26 @@ def init_chat_model(
     from langchain.chat_models import init_chat_model
     from pydantic import BaseModel, Field
 
+
     class GetWeather(BaseModel):
         '''Get the current weather in a given location'''
 
         location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
 
+
     class GetPopulation(BaseModel):
         '''Get the current population in a given location'''
 
         location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
 
+
     configurable_model = init_chat_model(
-        "gpt-4o",
-        configurable_fields=("model", "model_provider"),
-        temperature=0
+        "gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
     )
 
-    configurable_model_with_tools = configurable_model.bind_tools(…
+    configurable_model_with_tools = configurable_model.bind_tools(
+        [GetWeather, GetPopulation]
+    )
     configurable_model_with_tools.invoke(
         "Which city is hotter today and which is bigger: LA or NY?"
     )
@@ -263,7 +265,7 @@ def init_chat_model(
 
     configurable_model_with_tools.invoke(
         "Which city is hotter today and which is bigger: LA or NY?",
-        config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
+        config={"configurable": {"model": "claude-3-5-sonnet-latest"}},
     )
     # Claude-3.5 sonnet response with tools
@@ -528,12 +530,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         self,
         *,
         default_config: dict | None = None,
-        configurable_fields: …
+        configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
         config_prefix: str = "",
         queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
     ) -> None:
         self._default_config: dict = default_config or {}
-        self._configurable_fields: …
+        self._configurable_fields: Literal["any"] | list[str] = (
             configurable_fields if configurable_fields == "any" else list(configurable_fields)
         )
         self._config_prefix = (
@@ -636,11 +638,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         # This is a version of LanguageModelInput which replaces the abstract
         # base class BaseMessage with a union of its subclasses, which makes
         # for a much better schema.
-        return Union[
-            str,
-            Union[StringPromptValue, ChatPromptValueConcrete],
-            list[AnyMessage],
-        ]
+        return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
 
     @override
     def invoke(
@@ -682,7 +680,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch(
         self,
         inputs: list[LanguageModelInput],
-        config: …
+        config: RunnableConfig | list[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any | None,
@@ -710,7 +708,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch(
         self,
         inputs: list[LanguageModelInput],
-        config: …
+        config: RunnableConfig | list[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any | None,
@@ -738,11 +736,11 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: …
+        config: RunnableConfig | Sequence[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
-    ) -> Iterator[tuple[int, …
+    ) -> Iterator[tuple[int, Any | Exception]]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
         if config is None or isinstance(config, dict) or len(config) <= 1:
@@ -767,7 +765,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: …
+        config: RunnableConfig | Sequence[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -865,7 +863,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         exclude_types: Sequence[str] | None = None,
         exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
-    ) -> …
+    ) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
         async for x in self._model(config).astream_log(  # type: ignore[call-overload, misc]
             input,
             config=config,
@@ -913,7 +911,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     # Explicitly added to satisfy downstream linters.
     def bind_tools(
         self,
-        tools: Sequence[…
+        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, AIMessage]:
         return self.__getattr__("bind_tools")(tools, **kwargs)
@@ -921,7 +919,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     # Explicitly added to satisfy downstream linters.
     def with_structured_output(
         self,
-        schema: …
+        schema: dict | type[BaseModel],
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, …
+    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
         return self.__getattr__("with_structured_output")(schema, **kwargs)
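Note: the `batch`/`abatch` signatures above also accept one config per input. A sketch of per-input model selection, again assuming the provider packages are installed:

```python
from langchain.chat_models import init_chat_model

configurable = init_chat_model(temperature=0)

# One RunnableConfig per input, matching the widened
# `config: RunnableConfig | list[RunnableConfig] | None` annotation.
replies = configurable.batch(
    ["hello", "bonjour"],
    config=[
        {"configurable": {"model": "gpt-4o"}},
        {"configurable": {"model": "claude-3-5-sonnet-latest"}},
    ],
)
```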
langchain/documents/__init__.py
CHANGED
langchain/embeddings/__init__.py
CHANGED
langchain/embeddings/base.py
CHANGED
@@ -1,6 +1,8 @@
+"""Factory functions for embeddings."""
+
 import functools
 from importlib import util
-from typing import Any, …
+from typing import Any
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables import Runnable
@@ -124,7 +126,7 @@ def init_embeddings(
     *,
     provider: str | None = None,
     **kwargs: Any,
-) -> …
+) -> Embeddings | Runnable[Any, list[float]]:
     """Initialize an embeddings model from a model name and optional provider.
 
     **Note:** Must have the integration package corresponding to the model provider
@@ -160,17 +162,11 @@ def init_embeddings(
     model.embed_query("Hello, world!")
 
     # Using explicit provider
-    model = init_embeddings(
-        model="text-embedding-3-small",
-        provider="openai"
-    )
+    model = init_embeddings(model="text-embedding-3-small", provider="openai")
     model.embed_documents(["Hello, world!", "Goodbye, world!"])
 
     # With additional parameters
-    model = init_embeddings(
-        "openai:text-embedding-3-small",
-        api_key="sk-..."
-    )
+    model = init_embeddings("openai:text-embedding-3-small", api_key="sk-...")
 
     .. versionadded:: 0.3.9
 
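Note: the docstring's inline-provider form in one runnable-looking sketch (assumes `init_embeddings` is importable from `langchain.embeddings`, `langchain-openai` is installed, and `OPENAI_API_KEY` is set):

```python
from langchain.embeddings import init_embeddings

embedder = init_embeddings("openai:text-embedding-3-small")
vector = embedder.embed_query("Hello, world!")
# text-embedding-3-small returns 1536-dimensional vectors by default
assert len(vector) == 1536
```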
langchain/embeddings/cache.py
CHANGED
@@ -13,7 +13,7 @@ import hashlib
 import json
 import uuid
 import warnings
-from typing import TYPE_CHECKING, Literal, …
+from typing import TYPE_CHECKING, Literal, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils.iter import batch_iterate
@@ -178,7 +178,7 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[…
+        vectors: list[list[float] | None] = self.document_embedding_store.mget(
             texts,
         )
         all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
@@ -210,7 +210,7 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[…
+        vectors: list[list[float] | None] = await self.document_embedding_store.amget(texts)
         all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
 
         # batch_iterate supports None batch_size which returns all elements at once
@@ -285,11 +285,8 @@ class CacheBackedEmbeddings(Embeddings):
         *,
         namespace: str = "",
         batch_size: int | None = None,
-        query_embedding_cache: …
-        key_encoder: Union[
-            Callable[[str], str],
-            Literal["sha1", "blake2b", "sha256", "sha512"],
-        ] = "sha1",
+        query_embedding_cache: bool | ByteStore = False,
+        key_encoder: Callable[[str], str] | Literal["sha1", "blake2b", "sha256", "sha512"] = "sha1",
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
 