langchain 1.0.0a4__py3-none-any.whl → 1.0.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain might be problematic. Click here for more details.
- langchain/__init__.py +1 -1
- langchain/_internal/_lazy_import.py +2 -3
- langchain/_internal/_prompts.py +11 -18
- langchain/_internal/_typing.py +3 -3
- langchain/agents/_internal/_typing.py +2 -2
- langchain/agents/middleware/__init__.py +3 -0
- langchain/agents/middleware/dynamic_system_prompt.py +105 -0
- langchain/agents/middleware/human_in_the_loop.py +213 -88
- langchain/agents/middleware/prompt_caching.py +16 -8
- langchain/agents/middleware/summarization.py +2 -2
- langchain/agents/middleware/types.py +52 -11
- langchain/agents/middleware_agent.py +151 -94
- langchain/agents/react_agent.py +86 -61
- langchain/agents/structured_output.py +29 -24
- langchain/agents/tool_node.py +71 -65
- langchain/chat_models/base.py +28 -32
- langchain/embeddings/base.py +4 -10
- langchain/embeddings/cache.py +5 -8
- langchain/storage/encoder_backed.py +7 -4
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/METADATA +17 -17
- langchain-1.0.0a6.dist-info/RECORD +39 -0
- langchain/agents/interrupt.py +0 -92
- langchain/agents/middleware/_utils.py +0 -11
- langchain-1.0.0a4.dist-info/RECORD +0 -40
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/entry_points.txt +0 -0
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/licenses/LICENSE +0 -0
langchain/agents/tool_node.py
CHANGED
|
@@ -23,10 +23,12 @@ Typical Usage:
|
|
|
23
23
|
from langchain_core.tools import tool
|
|
24
24
|
from langchain.agents import ToolNode
|
|
25
25
|
|
|
26
|
+
|
|
26
27
|
@tool
|
|
27
28
|
def my_tool(x: int) -> str:
|
|
28
29
|
return f"Result: {x}"
|
|
29
30
|
|
|
31
|
+
|
|
30
32
|
tool_node = ToolNode([my_tool])
|
|
31
33
|
```
|
|
32
34
|
"""
|
|
@@ -38,6 +40,7 @@ import inspect
|
|
|
38
40
|
import json
|
|
39
41
|
from copy import copy, deepcopy
|
|
40
42
|
from dataclasses import replace
|
|
43
|
+
from types import UnionType
|
|
41
44
|
from typing import (
|
|
42
45
|
TYPE_CHECKING,
|
|
43
46
|
Annotated,
|
|
@@ -85,11 +88,19 @@ INVALID_TOOL_NAME_ERROR_TEMPLATE = (
|
|
|
85
88
|
"Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
|
|
86
89
|
)
|
|
87
90
|
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
|
88
|
-
TOOL_EXECUTION_ERROR_TEMPLATE =
|
|
89
|
-
|
|
91
|
+
TOOL_EXECUTION_ERROR_TEMPLATE = (
|
|
92
|
+
"Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
|
|
93
|
+
" {error}\n"
|
|
94
|
+
" Please fix the error and try again."
|
|
95
|
+
)
|
|
96
|
+
TOOL_INVOCATION_ERROR_TEMPLATE = (
|
|
97
|
+
"Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
|
|
98
|
+
" {error}\n"
|
|
99
|
+
" Please fix the error and try again."
|
|
100
|
+
)
|
|
90
101
|
|
|
91
102
|
|
|
92
|
-
def msg_content_output(output: Any) ->
|
|
103
|
+
def msg_content_output(output: Any) -> str | list[dict]:
|
|
93
104
|
"""Convert tool output to valid message content format.
|
|
94
105
|
|
|
95
106
|
LangChain ToolMessages accept either string content or a list of content blocks.
|
|
@@ -159,13 +170,7 @@ def _default_handle_tool_errors(e: Exception) -> str:
|
|
|
159
170
|
def _handle_tool_error(
|
|
160
171
|
e: Exception,
|
|
161
172
|
*,
|
|
162
|
-
flag: Union[
|
|
163
|
-
bool,
|
|
164
|
-
str,
|
|
165
|
-
Callable[..., str],
|
|
166
|
-
type[Exception],
|
|
167
|
-
tuple[type[Exception], ...],
|
|
168
|
-
],
|
|
173
|
+
flag: bool | str | Callable[..., str] | type[Exception] | tuple[type[Exception], ...],
|
|
169
174
|
) -> str:
|
|
170
175
|
"""Generate error message content based on exception handling configuration.
|
|
171
176
|
|
|
@@ -242,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception],
|
|
|
242
247
|
type_hints = get_type_hints(handler)
|
|
243
248
|
if first_param.name in type_hints:
|
|
244
249
|
origin = get_origin(first_param.annotation)
|
|
245
|
-
if origin
|
|
250
|
+
if origin in [Union, UnionType]:
|
|
246
251
|
args = get_args(first_param.annotation)
|
|
247
252
|
if all(issubclass(arg, Exception) for arg in args):
|
|
248
253
|
return tuple(args)
|
|
@@ -317,7 +322,8 @@ class ToolNode(RunnableCallable):
|
|
|
317
322
|
error template containing the exception details.
|
|
318
323
|
- **str**: Catch all errors and return a ToolMessage with this custom
|
|
319
324
|
error message string.
|
|
320
|
-
- **type[Exception]**: Only catch exceptions with the specified type and
|
|
325
|
+
- **type[Exception]**: Only catch exceptions with the specified type and
|
|
326
|
+
return the default error message for it.
|
|
321
327
|
- **tuple[type[Exception], ...]**: Only catch exceptions with the specified
|
|
322
328
|
types and return default error messages for them.
|
|
323
329
|
- **Callable[..., str]**: Catch exceptions matching the callable's signature
|
|
@@ -369,21 +375,24 @@ class ToolNode(RunnableCallable):
|
|
|
369
375
|
def handle_errors(e: ValueError) -> str:
|
|
370
376
|
return "Invalid input provided"
|
|
371
377
|
|
|
378
|
+
|
|
372
379
|
tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
|
|
373
380
|
```
|
|
374
|
-
"""
|
|
381
|
+
""" # noqa: E501
|
|
375
382
|
|
|
376
383
|
name: str = "tools"
|
|
377
384
|
|
|
378
385
|
def __init__(
|
|
379
386
|
self,
|
|
380
|
-
tools: Sequence[
|
|
387
|
+
tools: Sequence[BaseTool | Callable],
|
|
381
388
|
*,
|
|
382
389
|
name: str = "tools",
|
|
383
390
|
tags: list[str] | None = None,
|
|
384
|
-
handle_tool_errors:
|
|
385
|
-
|
|
386
|
-
|
|
391
|
+
handle_tool_errors: bool
|
|
392
|
+
| str
|
|
393
|
+
| Callable[..., str]
|
|
394
|
+
| type[Exception]
|
|
395
|
+
| tuple[type[Exception], ...] = _default_handle_tool_errors,
|
|
387
396
|
messages_key: str = "messages",
|
|
388
397
|
) -> None:
|
|
389
398
|
"""Initialize the ToolNode with the provided tools and configuration.
|
|
@@ -417,11 +426,7 @@ class ToolNode(RunnableCallable):
|
|
|
417
426
|
|
|
418
427
|
def _func(
|
|
419
428
|
self,
|
|
420
|
-
input: Union[
|
|
421
|
-
list[AnyMessage],
|
|
422
|
-
dict[str, Any],
|
|
423
|
-
BaseModel,
|
|
424
|
-
],
|
|
429
|
+
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
|
425
430
|
config: RunnableConfig,
|
|
426
431
|
*,
|
|
427
432
|
store: Optional[BaseStore], # noqa: UP045
|
|
@@ -436,11 +441,7 @@ class ToolNode(RunnableCallable):
|
|
|
436
441
|
|
|
437
442
|
async def _afunc(
|
|
438
443
|
self,
|
|
439
|
-
input: Union[
|
|
440
|
-
list[AnyMessage],
|
|
441
|
-
dict[str, Any],
|
|
442
|
-
BaseModel,
|
|
443
|
-
],
|
|
444
|
+
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
|
444
445
|
config: RunnableConfig,
|
|
445
446
|
*,
|
|
446
447
|
store: Optional[BaseStore], # noqa: UP045
|
|
@@ -454,9 +455,9 @@ class ToolNode(RunnableCallable):
|
|
|
454
455
|
|
|
455
456
|
def _combine_tool_outputs(
|
|
456
457
|
self,
|
|
457
|
-
outputs: list[
|
|
458
|
+
outputs: list[ToolMessage | Command],
|
|
458
459
|
input_type: Literal["list", "dict", "tool_calls"],
|
|
459
|
-
) -> list[
|
|
460
|
+
) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]:
|
|
460
461
|
# preserve existing behavior for non-command tool outputs for backwards
|
|
461
462
|
# compatibility
|
|
462
463
|
if not any(isinstance(output, Command) for output in outputs):
|
|
@@ -499,7 +500,7 @@ class ToolNode(RunnableCallable):
|
|
|
499
500
|
call: ToolCall,
|
|
500
501
|
input_type: Literal["list", "dict", "tool_calls"],
|
|
501
502
|
config: RunnableConfig,
|
|
502
|
-
) ->
|
|
503
|
+
) -> ToolMessage | Command:
|
|
503
504
|
"""Run a single tool call synchronously."""
|
|
504
505
|
if invalid_tool_message := self._validate_tool_call(call):
|
|
505
506
|
return invalid_tool_message
|
|
@@ -515,10 +516,12 @@ class ToolNode(RunnableCallable):
|
|
|
515
516
|
|
|
516
517
|
# GraphInterrupt is a special exception that will always be raised.
|
|
517
518
|
# It can be triggered in the following scenarios,
|
|
518
|
-
# Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
|
|
519
|
+
# Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
|
|
520
|
+
# most commonly:
|
|
519
521
|
# (1) a GraphInterrupt is raised inside a tool
|
|
520
522
|
# (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
|
|
521
|
-
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
|
|
523
|
+
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
|
|
524
|
+
# called as a tool
|
|
522
525
|
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
|
|
523
526
|
except GraphBubbleUp:
|
|
524
527
|
raise
|
|
@@ -553,7 +556,7 @@ class ToolNode(RunnableCallable):
|
|
|
553
556
|
if isinstance(response, Command):
|
|
554
557
|
return self._validate_tool_command(response, call, input_type)
|
|
555
558
|
if isinstance(response, ToolMessage):
|
|
556
|
-
response.content = cast("
|
|
559
|
+
response.content = cast("str | list", msg_content_output(response.content))
|
|
557
560
|
return response
|
|
558
561
|
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
|
|
559
562
|
raise TypeError(msg)
|
|
@@ -563,7 +566,7 @@ class ToolNode(RunnableCallable):
|
|
|
563
566
|
call: ToolCall,
|
|
564
567
|
input_type: Literal["list", "dict", "tool_calls"],
|
|
565
568
|
config: RunnableConfig,
|
|
566
|
-
) ->
|
|
569
|
+
) -> ToolMessage | Command:
|
|
567
570
|
"""Run a single tool call asynchronously."""
|
|
568
571
|
if invalid_tool_message := self._validate_tool_call(call):
|
|
569
572
|
return invalid_tool_message
|
|
@@ -579,10 +582,12 @@ class ToolNode(RunnableCallable):
|
|
|
579
582
|
|
|
580
583
|
# GraphInterrupt is a special exception that will always be raised.
|
|
581
584
|
# It can be triggered in the following scenarios,
|
|
582
|
-
# Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
|
|
585
|
+
# Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
|
|
586
|
+
# most commonly:
|
|
583
587
|
# (1) a GraphInterrupt is raised inside a tool
|
|
584
588
|
# (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
|
|
585
|
-
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
|
|
589
|
+
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
|
|
590
|
+
# called as a tool
|
|
586
591
|
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
|
|
587
592
|
except GraphBubbleUp:
|
|
588
593
|
raise
|
|
@@ -618,18 +623,14 @@ class ToolNode(RunnableCallable):
|
|
|
618
623
|
if isinstance(response, Command):
|
|
619
624
|
return self._validate_tool_command(response, call, input_type)
|
|
620
625
|
if isinstance(response, ToolMessage):
|
|
621
|
-
response.content = cast("
|
|
626
|
+
response.content = cast("str | list", msg_content_output(response.content))
|
|
622
627
|
return response
|
|
623
628
|
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
|
|
624
629
|
raise TypeError(msg)
|
|
625
630
|
|
|
626
631
|
def _parse_input(
|
|
627
632
|
self,
|
|
628
|
-
input: Union[
|
|
629
|
-
list[AnyMessage],
|
|
630
|
-
dict[str, Any],
|
|
631
|
-
BaseModel,
|
|
632
|
-
],
|
|
633
|
+
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
|
633
634
|
store: BaseStore | None,
|
|
634
635
|
) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
|
|
635
636
|
input_type: Literal["list", "dict", "tool_calls"]
|
|
@@ -676,11 +677,7 @@ class ToolNode(RunnableCallable):
|
|
|
676
677
|
def _inject_state(
|
|
677
678
|
self,
|
|
678
679
|
tool_call: ToolCall,
|
|
679
|
-
input: Union[
|
|
680
|
-
list[AnyMessage],
|
|
681
|
-
dict[str, Any],
|
|
682
|
-
BaseModel,
|
|
683
|
-
],
|
|
680
|
+
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
|
684
681
|
) -> ToolCall:
|
|
685
682
|
state_args = self._tool_to_state_args[tool_call["name"]]
|
|
686
683
|
if state_args and isinstance(input, list):
|
|
@@ -737,11 +734,7 @@ class ToolNode(RunnableCallable):
|
|
|
737
734
|
def inject_tool_args(
|
|
738
735
|
self,
|
|
739
736
|
tool_call: ToolCall,
|
|
740
|
-
input: Union[
|
|
741
|
-
list[AnyMessage],
|
|
742
|
-
dict[str, Any],
|
|
743
|
-
BaseModel,
|
|
744
|
-
],
|
|
737
|
+
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
|
745
738
|
store: BaseStore | None,
|
|
746
739
|
) -> ToolCall:
|
|
747
740
|
"""Inject graph state and store into tool call arguments.
|
|
@@ -791,10 +784,12 @@ class ToolNode(RunnableCallable):
|
|
|
791
784
|
input_type: Literal["list", "dict", "tool_calls"],
|
|
792
785
|
) -> Command:
|
|
793
786
|
if isinstance(command.update, dict):
|
|
794
|
-
# input type is dict when ToolNode is invoked with a dict input
|
|
787
|
+
# input type is dict when ToolNode is invoked with a dict input
|
|
788
|
+
# (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
|
|
795
789
|
if input_type not in ("dict", "tool_calls"):
|
|
796
790
|
msg = (
|
|
797
|
-
|
|
791
|
+
"Tools can provide a dict in Command.update only when using dict "
|
|
792
|
+
f"with '{self._messages_key}' key as ToolNode input, "
|
|
798
793
|
f"got: {command.update} for tool '{call['name']}'"
|
|
799
794
|
)
|
|
800
795
|
raise ValueError(msg)
|
|
@@ -803,10 +798,12 @@ class ToolNode(RunnableCallable):
|
|
|
803
798
|
state_update = cast("dict[str, Any]", updated_command.update) or {}
|
|
804
799
|
messages_update = state_update.get(self._messages_key, [])
|
|
805
800
|
elif isinstance(command.update, list):
|
|
806
|
-
# Input type is list when ToolNode is invoked with a list input
|
|
801
|
+
# Input type is list when ToolNode is invoked with a list input
|
|
802
|
+
# (e.g. [AIMessage(..., tool_calls=[...])])
|
|
807
803
|
if input_type != "list":
|
|
808
804
|
msg = (
|
|
809
|
-
|
|
805
|
+
"Tools can provide a list of messages in Command.update "
|
|
806
|
+
"only when using list of messages as ToolNode input, "
|
|
810
807
|
f"got: {command.update} for tool '{call['name']}'"
|
|
811
808
|
)
|
|
812
809
|
raise ValueError(msg)
|
|
@@ -836,13 +833,17 @@ class ToolNode(RunnableCallable):
|
|
|
836
833
|
# Command.update if command is sent to the CURRENT graph
|
|
837
834
|
if updated_command.graph is None and not has_matching_tool_message:
|
|
838
835
|
example_update = (
|
|
839
|
-
'`Command(update={"messages":
|
|
836
|
+
'`Command(update={"messages": '
|
|
837
|
+
'[ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
|
|
840
838
|
if input_type == "dict"
|
|
841
|
-
else
|
|
839
|
+
else "`Command(update="
|
|
840
|
+
'[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
|
|
842
841
|
)
|
|
843
842
|
msg = (
|
|
844
|
-
|
|
845
|
-
"
|
|
843
|
+
"Expected to have a matching ToolMessage in Command.update "
|
|
844
|
+
f"for tool '{call['name']}', got: {messages_update}. "
|
|
845
|
+
"Every tool call (LLM requesting to call a tool) "
|
|
846
|
+
"in the message history MUST have a corresponding ToolMessage. "
|
|
846
847
|
f"You can fix it by modifying the tool to return {example_update}."
|
|
847
848
|
)
|
|
848
849
|
raise ValueError(msg)
|
|
@@ -850,7 +851,7 @@ class ToolNode(RunnableCallable):
|
|
|
850
851
|
|
|
851
852
|
|
|
852
853
|
def tools_condition(
|
|
853
|
-
state:
|
|
854
|
+
state: list[AnyMessage] | dict[str, Any] | BaseModel,
|
|
854
855
|
messages_key: str = "messages",
|
|
855
856
|
) -> Literal["tools", "__end__"]:
|
|
856
857
|
"""Conditional routing function for tool-calling workflows.
|
|
@@ -887,16 +888,18 @@ def tools_condition(
|
|
|
887
888
|
from langgraph.agents.tool_node import ToolNode, tools_condition
|
|
888
889
|
from typing_extensions import TypedDict
|
|
889
890
|
|
|
891
|
+
|
|
890
892
|
class State(TypedDict):
|
|
891
893
|
messages: list
|
|
892
894
|
|
|
895
|
+
|
|
893
896
|
graph = StateGraph(State)
|
|
894
897
|
graph.add_node("llm", call_model)
|
|
895
898
|
graph.add_node("tools", ToolNode([my_tool]))
|
|
896
899
|
graph.add_conditional_edges(
|
|
897
900
|
"llm",
|
|
898
901
|
tools_condition, # Routes to "tools" or "__end__"
|
|
899
|
-
{"tools": "tools", "__end__": "__end__"}
|
|
902
|
+
{"tools": "tools", "__end__": "__end__"},
|
|
900
903
|
)
|
|
901
904
|
```
|
|
902
905
|
|
|
@@ -956,6 +959,7 @@ class InjectedState(InjectedToolArg):
|
|
|
956
959
|
messages: List[BaseMessage]
|
|
957
960
|
foo: str
|
|
958
961
|
|
|
962
|
+
|
|
959
963
|
@tool
|
|
960
964
|
def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
|
|
961
965
|
'''Do something with state.'''
|
|
@@ -964,11 +968,13 @@ class InjectedState(InjectedToolArg):
|
|
|
964
968
|
else:
|
|
965
969
|
return "not enough messages"
|
|
966
970
|
|
|
971
|
+
|
|
967
972
|
@tool
|
|
968
973
|
def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
|
|
969
974
|
'''Do something else with state.'''
|
|
970
975
|
return foo + str(x + 1)
|
|
971
976
|
|
|
977
|
+
|
|
972
978
|
node = ToolNode([state_tool, foo_tool])
|
|
973
979
|
|
|
974
980
|
tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
|
|
@@ -982,8 +988,8 @@ class InjectedState(InjectedToolArg):
|
|
|
982
988
|
|
|
983
989
|
```pycon
|
|
984
990
|
[
|
|
985
|
-
ToolMessage(content=
|
|
986
|
-
ToolMessage(content=
|
|
991
|
+
ToolMessage(content="not enough messages", name="state_tool", tool_call_id="1"),
|
|
992
|
+
ToolMessage(content="bar2", name="foo_tool", tool_call_id="2"),
|
|
987
993
|
]
|
|
988
994
|
```
|
|
989
995
|
|
|
@@ -1078,7 +1084,7 @@ class InjectedStore(InjectedToolArg):
|
|
|
1078
1084
|
"""
|
|
1079
1085
|
|
|
1080
1086
|
|
|
1081
|
-
def _is_injection(type_arg: Any, injection_type: type[
|
|
1087
|
+
def _is_injection(type_arg: Any, injection_type: type[InjectedState | InjectedStore]) -> bool:
|
|
1082
1088
|
"""Check if a type argument represents an injection annotation.
|
|
1083
1089
|
|
|
1084
1090
|
This utility function determines whether a type annotation indicates that
|
langchain/chat_models/base.py
CHANGED
|
@@ -9,7 +9,6 @@ from typing import (
|
|
|
9
9
|
Any,
|
|
10
10
|
Literal,
|
|
11
11
|
TypeAlias,
|
|
12
|
-
Union,
|
|
13
12
|
cast,
|
|
14
13
|
overload,
|
|
15
14
|
)
|
|
@@ -55,7 +54,7 @@ def init_chat_model(
|
|
|
55
54
|
model: str | None = None,
|
|
56
55
|
*,
|
|
57
56
|
model_provider: str | None = None,
|
|
58
|
-
configurable_fields:
|
|
57
|
+
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
|
|
59
58
|
config_prefix: str | None = None,
|
|
60
59
|
**kwargs: Any,
|
|
61
60
|
) -> _ConfigurableModel: ...
|
|
@@ -68,10 +67,10 @@ def init_chat_model(
|
|
|
68
67
|
model: str | None = None,
|
|
69
68
|
*,
|
|
70
69
|
model_provider: str | None = None,
|
|
71
|
-
configurable_fields:
|
|
70
|
+
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
|
|
72
71
|
config_prefix: str | None = None,
|
|
73
72
|
**kwargs: Any,
|
|
74
|
-
) ->
|
|
73
|
+
) -> BaseChatModel | _ConfigurableModel:
|
|
75
74
|
"""Initialize a ChatModel from the model name and provider.
|
|
76
75
|
|
|
77
76
|
**Note:** Must have the integration package corresponding to the model provider
|
|
@@ -191,14 +190,12 @@ def init_chat_model(
|
|
|
191
190
|
configurable_model = init_chat_model(temperature=0)
|
|
192
191
|
|
|
193
192
|
configurable_model.invoke(
|
|
194
|
-
"what's your name",
|
|
195
|
-
config={"configurable": {"model": "gpt-4o"}}
|
|
193
|
+
"what's your name", config={"configurable": {"model": "gpt-4o"}}
|
|
196
194
|
)
|
|
197
195
|
# GPT-4o response
|
|
198
196
|
|
|
199
197
|
configurable_model.invoke(
|
|
200
|
-
"what's your name",
|
|
201
|
-
config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
|
|
198
|
+
"what's your name", config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
|
|
202
199
|
)
|
|
203
200
|
# claude-3.5 sonnet response
|
|
204
201
|
|
|
@@ -213,7 +210,7 @@ def init_chat_model(
|
|
|
213
210
|
"openai:gpt-4o",
|
|
214
211
|
configurable_fields="any", # this allows us to configure other params like temperature, max_tokens, etc at runtime.
|
|
215
212
|
config_prefix="foo",
|
|
216
|
-
temperature=0
|
|
213
|
+
temperature=0,
|
|
217
214
|
)
|
|
218
215
|
|
|
219
216
|
configurable_model_with_default.invoke("what's your name")
|
|
@@ -224,9 +221,9 @@ def init_chat_model(
|
|
|
224
221
|
config={
|
|
225
222
|
"configurable": {
|
|
226
223
|
"foo_model": "anthropic:claude-3-5-sonnet-latest",
|
|
227
|
-
"foo_temperature": 0.6
|
|
224
|
+
"foo_temperature": 0.6,
|
|
228
225
|
}
|
|
229
|
-
}
|
|
226
|
+
},
|
|
230
227
|
)
|
|
231
228
|
# Claude-3.5 sonnet response with temperature 0.6
|
|
232
229
|
|
|
@@ -241,23 +238,26 @@ def init_chat_model(
|
|
|
241
238
|
from langchain.chat_models import init_chat_model
|
|
242
239
|
from pydantic import BaseModel, Field
|
|
243
240
|
|
|
241
|
+
|
|
244
242
|
class GetWeather(BaseModel):
|
|
245
243
|
'''Get the current weather in a given location'''
|
|
246
244
|
|
|
247
245
|
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
|
248
246
|
|
|
247
|
+
|
|
249
248
|
class GetPopulation(BaseModel):
|
|
250
249
|
'''Get the current population in a given location'''
|
|
251
250
|
|
|
252
251
|
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
|
253
252
|
|
|
253
|
+
|
|
254
254
|
configurable_model = init_chat_model(
|
|
255
|
-
"gpt-4o",
|
|
256
|
-
configurable_fields=("model", "model_provider"),
|
|
257
|
-
temperature=0
|
|
255
|
+
"gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
|
|
258
256
|
)
|
|
259
257
|
|
|
260
|
-
configurable_model_with_tools = configurable_model.bind_tools(
|
|
258
|
+
configurable_model_with_tools = configurable_model.bind_tools(
|
|
259
|
+
[GetWeather, GetPopulation]
|
|
260
|
+
)
|
|
261
261
|
configurable_model_with_tools.invoke(
|
|
262
262
|
"Which city is hotter today and which is bigger: LA or NY?"
|
|
263
263
|
)
|
|
@@ -265,7 +265,7 @@ def init_chat_model(
|
|
|
265
265
|
|
|
266
266
|
configurable_model_with_tools.invoke(
|
|
267
267
|
"Which city is hotter today and which is bigger: LA or NY?",
|
|
268
|
-
config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
|
|
268
|
+
config={"configurable": {"model": "claude-3-5-sonnet-latest"}},
|
|
269
269
|
)
|
|
270
270
|
# Claude-3.5 sonnet response with tools
|
|
271
271
|
|
|
@@ -530,12 +530,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
530
530
|
self,
|
|
531
531
|
*,
|
|
532
532
|
default_config: dict | None = None,
|
|
533
|
-
configurable_fields:
|
|
533
|
+
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
|
|
534
534
|
config_prefix: str = "",
|
|
535
535
|
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
|
|
536
536
|
) -> None:
|
|
537
537
|
self._default_config: dict = default_config or {}
|
|
538
|
-
self._configurable_fields:
|
|
538
|
+
self._configurable_fields: Literal["any"] | list[str] = (
|
|
539
539
|
configurable_fields if configurable_fields == "any" else list(configurable_fields)
|
|
540
540
|
)
|
|
541
541
|
self._config_prefix = (
|
|
@@ -638,11 +638,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
638
638
|
# This is a version of LanguageModelInput which replaces the abstract
|
|
639
639
|
# base class BaseMessage with a union of its subclasses, which makes
|
|
640
640
|
# for a much better schema.
|
|
641
|
-
return Union[
|
|
642
|
-
str,
|
|
643
|
-
Union[StringPromptValue, ChatPromptValueConcrete],
|
|
644
|
-
list[AnyMessage],
|
|
645
|
-
]
|
|
641
|
+
return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
|
|
646
642
|
|
|
647
643
|
@override
|
|
648
644
|
def invoke(
|
|
@@ -684,7 +680,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
684
680
|
def batch(
|
|
685
681
|
self,
|
|
686
682
|
inputs: list[LanguageModelInput],
|
|
687
|
-
config:
|
|
683
|
+
config: RunnableConfig | list[RunnableConfig] | None = None,
|
|
688
684
|
*,
|
|
689
685
|
return_exceptions: bool = False,
|
|
690
686
|
**kwargs: Any | None,
|
|
@@ -712,7 +708,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
712
708
|
async def abatch(
|
|
713
709
|
self,
|
|
714
710
|
inputs: list[LanguageModelInput],
|
|
715
|
-
config:
|
|
711
|
+
config: RunnableConfig | list[RunnableConfig] | None = None,
|
|
716
712
|
*,
|
|
717
713
|
return_exceptions: bool = False,
|
|
718
714
|
**kwargs: Any | None,
|
|
@@ -740,11 +736,11 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
740
736
|
def batch_as_completed(
|
|
741
737
|
self,
|
|
742
738
|
inputs: Sequence[LanguageModelInput],
|
|
743
|
-
config:
|
|
739
|
+
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
|
|
744
740
|
*,
|
|
745
741
|
return_exceptions: bool = False,
|
|
746
742
|
**kwargs: Any,
|
|
747
|
-
) -> Iterator[tuple[int,
|
|
743
|
+
) -> Iterator[tuple[int, Any | Exception]]:
|
|
748
744
|
config = config or None
|
|
749
745
|
# If <= 1 config use the underlying models batch implementation.
|
|
750
746
|
if config is None or isinstance(config, dict) or len(config) <= 1:
|
|
@@ -769,7 +765,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
769
765
|
async def abatch_as_completed(
|
|
770
766
|
self,
|
|
771
767
|
inputs: Sequence[LanguageModelInput],
|
|
772
|
-
config:
|
|
768
|
+
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
|
|
773
769
|
*,
|
|
774
770
|
return_exceptions: bool = False,
|
|
775
771
|
**kwargs: Any,
|
|
@@ -867,7 +863,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
867
863
|
exclude_types: Sequence[str] | None = None,
|
|
868
864
|
exclude_tags: Sequence[str] | None = None,
|
|
869
865
|
**kwargs: Any,
|
|
870
|
-
) ->
|
|
866
|
+
) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
|
|
871
867
|
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
|
|
872
868
|
input,
|
|
873
869
|
config=config,
|
|
@@ -915,7 +911,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
915
911
|
# Explicitly added to satisfy downstream linters.
|
|
916
912
|
def bind_tools(
|
|
917
913
|
self,
|
|
918
|
-
tools: Sequence[
|
|
914
|
+
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
|
|
919
915
|
**kwargs: Any,
|
|
920
916
|
) -> Runnable[LanguageModelInput, AIMessage]:
|
|
921
917
|
return self.__getattr__("bind_tools")(tools, **kwargs)
|
|
@@ -923,7 +919,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
|
|
923
919
|
# Explicitly added to satisfy downstream linters.
|
|
924
920
|
def with_structured_output(
|
|
925
921
|
self,
|
|
926
|
-
schema:
|
|
922
|
+
schema: dict | type[BaseModel],
|
|
927
923
|
**kwargs: Any,
|
|
928
|
-
) -> Runnable[LanguageModelInput,
|
|
924
|
+
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
|
929
925
|
return self.__getattr__("with_structured_output")(schema, **kwargs)
|
langchain/embeddings/base.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
4
|
from importlib import util
|
|
5
|
-
from typing import Any, Union
|
|
5
|
+
from typing import Any
|
|
6
6
|
|
|
7
7
|
from langchain_core.embeddings import Embeddings
|
|
8
8
|
from langchain_core.runnables import Runnable
|
|
@@ -126,7 +126,7 @@ def init_embeddings(
|
|
|
126
126
|
*,
|
|
127
127
|
provider: str | None = None,
|
|
128
128
|
**kwargs: Any,
|
|
129
|
-
) ->
|
|
129
|
+
) -> Embeddings | Runnable[Any, list[float]]:
|
|
130
130
|
"""Initialize an embeddings model from a model name and optional provider.
|
|
131
131
|
|
|
132
132
|
**Note:** Must have the integration package corresponding to the model provider
|
|
@@ -162,17 +162,11 @@ def init_embeddings(
|
|
|
162
162
|
model.embed_query("Hello, world!")
|
|
163
163
|
|
|
164
164
|
# Using explicit provider
|
|
165
|
-
model = init_embeddings(
|
|
166
|
-
model="text-embedding-3-small",
|
|
167
|
-
provider="openai"
|
|
168
|
-
)
|
|
165
|
+
model = init_embeddings(model="text-embedding-3-small", provider="openai")
|
|
169
166
|
model.embed_documents(["Hello, world!", "Goodbye, world!"])
|
|
170
167
|
|
|
171
168
|
# With additional parameters
|
|
172
|
-
model = init_embeddings(
|
|
173
|
-
"openai:text-embedding-3-small",
|
|
174
|
-
api_key="sk-..."
|
|
175
|
-
)
|
|
169
|
+
model = init_embeddings("openai:text-embedding-3-small", api_key="sk-...")
|
|
176
170
|
|
|
177
171
|
.. versionadded:: 0.3.9
|
|
178
172
|
|
langchain/embeddings/cache.py
CHANGED
|
@@ -13,7 +13,7 @@ import hashlib
|
|
|
13
13
|
import json
|
|
14
14
|
import uuid
|
|
15
15
|
import warnings
|
|
16
|
-
from typing import TYPE_CHECKING, Literal, Union, cast
|
|
16
|
+
from typing import TYPE_CHECKING, Literal, cast
|
|
17
17
|
|
|
18
18
|
from langchain_core.embeddings import Embeddings
|
|
19
19
|
from langchain_core.utils.iter import batch_iterate
|
|
@@ -178,7 +178,7 @@ class CacheBackedEmbeddings(Embeddings):
|
|
|
178
178
|
Returns:
|
|
179
179
|
A list of embeddings for the given texts.
|
|
180
180
|
"""
|
|
181
|
-
vectors: list[
|
|
181
|
+
vectors: list[list[float] | None] = self.document_embedding_store.mget(
|
|
182
182
|
texts,
|
|
183
183
|
)
|
|
184
184
|
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
|
|
@@ -210,7 +210,7 @@ class CacheBackedEmbeddings(Embeddings):
|
|
|
210
210
|
Returns:
|
|
211
211
|
A list of embeddings for the given texts.
|
|
212
212
|
"""
|
|
213
|
-
vectors: list[
|
|
213
|
+
vectors: list[list[float] | None] = await self.document_embedding_store.amget(texts)
|
|
214
214
|
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
|
|
215
215
|
|
|
216
216
|
# batch_iterate supports None batch_size which returns all elements at once
|
|
@@ -285,11 +285,8 @@ class CacheBackedEmbeddings(Embeddings):
|
|
|
285
285
|
*,
|
|
286
286
|
namespace: str = "",
|
|
287
287
|
batch_size: int | None = None,
|
|
288
|
-
query_embedding_cache:
|
|
289
|
-
key_encoder: Union[
|
|
290
|
-
Callable[[str], str],
|
|
291
|
-
Literal["sha1", "blake2b", "sha256", "sha512"],
|
|
292
|
-
] = "sha1",
|
|
288
|
+
query_embedding_cache: bool | ByteStore = False,
|
|
289
|
+
key_encoder: Callable[[str], str] | Literal["sha1", "blake2b", "sha256", "sha512"] = "sha1",
|
|
293
290
|
) -> CacheBackedEmbeddings:
|
|
294
291
|
"""On-ramp that adds the necessary serialization and encoding to the store.
|
|
295
292
|
|