langchain 1.0.0a4__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to their respective public registries; it is provided for informational purposes only.
@@ -23,10 +23,12 @@ Typical Usage:
 from langchain_core.tools import tool
 from langchain.agents import ToolNode
 
+
 @tool
 def my_tool(x: int) -> str:
     return f"Result: {x}"
 
+
 tool_node = ToolNode([my_tool])
 ```
 """
@@ -38,6 +40,7 @@ import inspect
 import json
 from copy import copy, deepcopy
 from dataclasses import replace
+from types import UnionType
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -85,11 +88,19 @@ INVALID_TOOL_NAME_ERROR_TEMPLATE = (
     "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
 )
 TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
-TOOL_EXECUTION_ERROR_TEMPLATE = "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again."
-TOOL_INVOCATION_ERROR_TEMPLATE = "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again."
+TOOL_EXECUTION_ERROR_TEMPLATE = (
+    "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
+    " {error}\n"
+    " Please fix the error and try again."
+)
+TOOL_INVOCATION_ERROR_TEMPLATE = (
+    "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
+    " {error}\n"
+    " Please fix the error and try again."
+)
 
 
-def msg_content_output(output: Any) -> Union[str, list[dict]]:
+def msg_content_output(output: Any) -> str | list[dict]:
     """Convert tool output to valid message content format.
 
     LangChain ToolMessages accept either string content or a list of content blocks.
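The template rewrite in the hunk above is purely mechanical: Python joins adjacent string literals at compile time, so the parenthesized multi-line form builds byte-for-byte the same template as the old single line. A quick standalone check (names local to this snippet, not part of the package):

```python
# Adjacent string literals are concatenated at compile time, so the
# wrapped form is identical to the original one-line template.
one_line = "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again."  # noqa: E501
wrapped = (
    "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
    " {error}\n"
    " Please fix the error and try again."
)
assert one_line == wrapped
```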
@@ -159,13 +170,7 @@ def _default_handle_tool_errors(e: Exception) -> str:
 def _handle_tool_error(
     e: Exception,
     *,
-    flag: Union[
-        bool,
-        str,
-        Callable[..., str],
-        type[Exception],
-        tuple[type[Exception], ...],
-    ],
+    flag: bool | str | Callable[..., str] | type[Exception] | tuple[type[Exception], ...],
 ) -> str:
     """Generate error message content based on exception handling configuration.
 
@@ -242,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
     type_hints = get_type_hints(handler)
     if first_param.name in type_hints:
         origin = get_origin(first_param.annotation)
-        if origin is Union:
+        if origin in [Union, UnionType]:
             args = get_args(first_param.annotation)
             if all(issubclass(arg, Exception) for arg in args):
                 return tuple(args)
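The new `UnionType` branch matters because the two union spellings report different origins to `typing.get_origin`: `Union[A, B]` yields `typing.Union`, while the PEP 604 form `A | B` yields `types.UnionType`. With only the old `origin is Union` check, a handler annotated `def handler(e: ValueError | KeyError) -> str` would never have its exception types inferred. A standalone illustration (Python 3.10+):

```python
from types import UnionType
from typing import Union, get_args, get_origin

# The legacy spelling reports typing.Union as its origin...
assert get_origin(Union[ValueError, KeyError]) is Union
# ...while the PEP 604 spelling reports types.UnionType,
# so the origin check must accept either form.
assert get_origin(ValueError | KeyError) is UnionType
# Both spellings expose the same member types via get_args.
assert get_args(ValueError | KeyError) == (ValueError, KeyError)
```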
@@ -317,7 +322,8 @@ class ToolNode(RunnableCallable):
           error template containing the exception details.
         - **str**: Catch all errors and return a ToolMessage with this custom
           error message string.
-        - **type[Exception]**: Only catch exceptions with the specified type and return the default error message for it.
+        - **type[Exception]**: Only catch exceptions with the specified type and
+          return the default error message for it.
         - **tuple[type[Exception], ...]**: Only catch exceptions with the specified
           types and return default error messages for them.
         - **Callable[..., str]**: Catch exceptions matching the callable's signature
@@ -369,21 +375,24 @@ class ToolNode(RunnableCallable):
         def handle_errors(e: ValueError) -> str:
             return "Invalid input provided"
 
+
         tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
         ```
-    """
+    """  # noqa: E501
 
     name: str = "tools"
 
     def __init__(
         self,
-        tools: Sequence[Union[BaseTool, Callable]],
+        tools: Sequence[BaseTool | Callable],
         *,
         name: str = "tools",
         tags: list[str] | None = None,
-        handle_tool_errors: Union[
-            bool, str, Callable[..., str], type[Exception], tuple[type[Exception], ...]
-        ] = _default_handle_tool_errors,
+        handle_tool_errors: bool
+        | str
+        | Callable[..., str]
+        | type[Exception]
+        | tuple[type[Exception], ...] = _default_handle_tool_errors,
         messages_key: str = "messages",
     ) -> None:
         """Initialize the ToolNode with the provided tools and configuration.
@@ -417,11 +426,7 @@ class ToolNode(RunnableCallable):
 
     def _func(
         self,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         config: RunnableConfig,
         *,
         store: Optional[BaseStore],  # noqa: UP045
@@ -436,11 +441,7 @@ class ToolNode(RunnableCallable):
 
     async def _afunc(
         self,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         config: RunnableConfig,
         *,
         store: Optional[BaseStore],  # noqa: UP045
@@ -454,9 +455,9 @@ class ToolNode(RunnableCallable):
 
     def _combine_tool_outputs(
         self,
-        outputs: list[Union[ToolMessage, Command]],
+        outputs: list[ToolMessage | Command],
         input_type: Literal["list", "dict", "tool_calls"],
-    ) -> list[Union[Command, list[ToolMessage], dict[str, list[ToolMessage]]]]:
+    ) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]:
         # preserve existing behavior for non-command tool outputs for backwards
         # compatibility
         if not any(isinstance(output, Command) for output in outputs):
@@ -499,7 +500,7 @@ class ToolNode(RunnableCallable):
         call: ToolCall,
         input_type: Literal["list", "dict", "tool_calls"],
         config: RunnableConfig,
-    ) -> Union[ToolMessage, Command]:
+    ) -> ToolMessage | Command:
         """Run a single tool call synchronously."""
         if invalid_tool_message := self._validate_tool_call(call):
             return invalid_tool_message
@@ -515,10 +516,12 @@ class ToolNode(RunnableCallable):
 
         # GraphInterrupt is a special exception that will always be raised.
         # It can be triggered in the following scenarios,
-        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation most commonly:
+        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
+        # most commonly:
         # (1) a GraphInterrupt is raised inside a tool
         # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
-        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
+        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
+        # called as a tool
         # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
         except GraphBubbleUp:
             raise
@@ -553,7 +556,7 @@ class ToolNode(RunnableCallable):
         if isinstance(response, Command):
             return self._validate_tool_command(response, call, input_type)
         if isinstance(response, ToolMessage):
-            response.content = cast("Union[str, list]", msg_content_output(response.content))
+            response.content = cast("str | list", msg_content_output(response.content))
             return response
         msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
         raise TypeError(msg)
@@ -563,7 +566,7 @@ class ToolNode(RunnableCallable):
         call: ToolCall,
         input_type: Literal["list", "dict", "tool_calls"],
         config: RunnableConfig,
-    ) -> Union[ToolMessage, Command]:
+    ) -> ToolMessage | Command:
         """Run a single tool call asynchronously."""
         if invalid_tool_message := self._validate_tool_call(call):
             return invalid_tool_message
@@ -579,10 +582,12 @@ class ToolNode(RunnableCallable):
 
         # GraphInterrupt is a special exception that will always be raised.
         # It can be triggered in the following scenarios,
-        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation most commonly:
+        # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
+        # most commonly:
         # (1) a GraphInterrupt is raised inside a tool
         # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
-        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
+        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
+        # called as a tool
         # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
         except GraphBubbleUp:
             raise
@@ -618,18 +623,14 @@ class ToolNode(RunnableCallable):
         if isinstance(response, Command):
             return self._validate_tool_command(response, call, input_type)
         if isinstance(response, ToolMessage):
-            response.content = cast("Union[str, list]", msg_content_output(response.content))
+            response.content = cast("str | list", msg_content_output(response.content))
             return response
         msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
         raise TypeError(msg)
 
     def _parse_input(
         self,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         store: BaseStore | None,
     ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
         input_type: Literal["list", "dict", "tool_calls"]
@@ -676,11 +677,7 @@ class ToolNode(RunnableCallable):
     def _inject_state(
         self,
         tool_call: ToolCall,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
     ) -> ToolCall:
         state_args = self._tool_to_state_args[tool_call["name"]]
         if state_args and isinstance(input, list):
@@ -737,11 +734,7 @@ class ToolNode(RunnableCallable):
     def inject_tool_args(
         self,
         tool_call: ToolCall,
-        input: Union[
-            list[AnyMessage],
-            dict[str, Any],
-            BaseModel,
-        ],
+        input: list[AnyMessage] | dict[str, Any] | BaseModel,
         store: BaseStore | None,
     ) -> ToolCall:
         """Inject graph state and store into tool call arguments.
@@ -791,10 +784,12 @@ class ToolNode(RunnableCallable):
         input_type: Literal["list", "dict", "tool_calls"],
     ) -> Command:
         if isinstance(command.update, dict):
-            # input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
+            # input type is dict when ToolNode is invoked with a dict input
+            # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
             if input_type not in ("dict", "tool_calls"):
                 msg = (
-                    f"Tools can provide a dict in Command.update only when using dict with '{self._messages_key}' key as ToolNode input, "
+                    "Tools can provide a dict in Command.update only when using dict "
+                    f"with '{self._messages_key}' key as ToolNode input, "
                     f"got: {command.update} for tool '{call['name']}'"
                 )
                 raise ValueError(msg)
@@ -803,10 +798,12 @@ class ToolNode(RunnableCallable):
             state_update = cast("dict[str, Any]", updated_command.update) or {}
             messages_update = state_update.get(self._messages_key, [])
         elif isinstance(command.update, list):
-            # Input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])])
+            # Input type is list when ToolNode is invoked with a list input
+            # (e.g. [AIMessage(..., tool_calls=[...])])
            if input_type != "list":
                msg = (
-                    f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, "
+                    "Tools can provide a list of messages in Command.update "
+                    "only when using list of messages as ToolNode input, "
                    f"got: {command.update} for tool '{call['name']}'"
                )
                raise ValueError(msg)
@@ -836,13 +833,17 @@ class ToolNode(RunnableCallable):
         # Command.update if command is sent to the CURRENT graph
         if updated_command.graph is None and not has_matching_tool_message:
             example_update = (
-                '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
+                '`Command(update={"messages": '
+                '[ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
                 if input_type == "dict"
-                else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
+                else "`Command(update="
+                '[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
             )
             msg = (
-                f"Expected to have a matching ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}. "
-                "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
+                "Expected to have a matching ToolMessage in Command.update "
+                f"for tool '{call['name']}', got: {messages_update}. "
+                "Every tool call (LLM requesting to call a tool) "
+                "in the message history MUST have a corresponding ToolMessage. "
                 f"You can fix it by modifying the tool to return {example_update}."
             )
             raise ValueError(msg)
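For context, a tool that passes this validation must include a `ToolMessage` answering its own tool call in `Command.update`. A hedged sketch, assuming `Command` from `langgraph.types` and `InjectedToolCallId` from `langchain_core.tools` (the `done` state key is invented for illustration):

```python
from typing import Annotated

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.types import Command


@tool
def mark_done(tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """Flip a flag in graph state and confirm completion."""
    # The matching ToolMessage keeps the message history consistent,
    # which is exactly what the validation above enforces.
    return Command(
        update={
            "done": True,  # hypothetical state key
            "messages": [ToolMessage("Success", tool_call_id=tool_call_id)],
        }
    )
```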
@@ -850,7 +851,7 @@
 
 
 def tools_condition(
-    state: Union[list[AnyMessage], dict[str, Any], BaseModel],
+    state: list[AnyMessage] | dict[str, Any] | BaseModel,
     messages_key: str = "messages",
 ) -> Literal["tools", "__end__"]:
     """Conditional routing function for tool-calling workflows.
@@ -887,16 +888,18 @@ def tools_condition(
         from langgraph.agents.tool_node import ToolNode, tools_condition
         from typing_extensions import TypedDict
 
+
         class State(TypedDict):
             messages: list
 
+
         graph = StateGraph(State)
         graph.add_node("llm", call_model)
         graph.add_node("tools", ToolNode([my_tool]))
         graph.add_conditional_edges(
             "llm",
             tools_condition,  # Routes to "tools" or "__end__"
-            {"tools": "tools", "__end__": "__end__"}
+            {"tools": "tools", "__end__": "__end__"},
         )
         ```
 
@@ -956,6 +959,7 @@ class InjectedState(InjectedToolArg):
             messages: List[BaseMessage]
             foo: str
 
+
         @tool
         def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
             '''Do something with state.'''
@@ -964,11 +968,13 @@ class InjectedState(InjectedToolArg):
             else:
                 return "not enough messages"
 
+
         @tool
         def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
             '''Do something else with state.'''
             return foo + str(x + 1)
 
+
         node = ToolNode([state_tool, foo_tool])
 
         tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
@@ -982,8 +988,8 @@ class InjectedState(InjectedToolArg):
 
         ```pycon
         [
-            ToolMessage(content='not enough messages', name='state_tool', tool_call_id='1'),
-            ToolMessage(content='bar2', name='foo_tool', tool_call_id='2')
+            ToolMessage(content="not enough messages", name="state_tool", tool_call_id="1"),
+            ToolMessage(content="bar2", name="foo_tool", tool_call_id="2"),
         ]
         ```
 
@@ -1078,7 +1084,7 @@ class InjectedStore(InjectedToolArg):
     """
 
 
-def _is_injection(type_arg: Any, injection_type: type[Union[InjectedState, InjectedStore]]) -> bool:
+def _is_injection(type_arg: Any, injection_type: type[InjectedState | InjectedStore]) -> bool:
     """Check if a type argument represents an injection annotation.
 
     This utility function determines whether a type annotation indicates that
@@ -9,7 +9,6 @@ from typing import (
     Any,
     Literal,
     TypeAlias,
-    Union,
     cast,
     overload,
 )
@@ -55,7 +54,7 @@ def init_chat_model(
     model: str | None = None,
     *,
     model_provider: str | None = None,
-    configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
+    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
     config_prefix: str | None = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
@@ -68,10 +67,10 @@ def init_chat_model(
     model: str | None = None,
     *,
     model_provider: str | None = None,
-    configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
+    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
     config_prefix: str | None = None,
     **kwargs: Any,
-) -> Union[BaseChatModel, _ConfigurableModel]:
+) -> BaseChatModel | _ConfigurableModel:
     """Initialize a ChatModel from the model name and provider.
 
     **Note:** Must have the integration package corresponding to the model provider
@@ -191,14 +190,12 @@ def init_chat_model(
         configurable_model = init_chat_model(temperature=0)
 
         configurable_model.invoke(
-            "what's your name",
-            config={"configurable": {"model": "gpt-4o"}}
+            "what's your name", config={"configurable": {"model": "gpt-4o"}}
         )
         # GPT-4o response
 
         configurable_model.invoke(
-            "what's your name",
-            config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
+            "what's your name", config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
         )
         # claude-3.5 sonnet response
 
@@ -213,7 +210,7 @@ def init_chat_model(
         "openai:gpt-4o",
         configurable_fields="any",  # this allows us to configure other params like temperature, max_tokens, etc at runtime.
         config_prefix="foo",
-        temperature=0
+        temperature=0,
     )
 
     configurable_model_with_default.invoke("what's your name")
@@ -224,9 +221,9 @@ def init_chat_model(
         config={
             "configurable": {
                 "foo_model": "anthropic:claude-3-5-sonnet-latest",
-                "foo_temperature": 0.6
+                "foo_temperature": 0.6,
             }
-        }
+        },
     )
     # Claude-3.5 sonnet response with temperature 0.6
 
@@ -241,23 +238,26 @@ def init_chat_model(
         from langchain.chat_models import init_chat_model
         from pydantic import BaseModel, Field
 
+
         class GetWeather(BaseModel):
             '''Get the current weather in a given location'''
 
             location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
 
+
         class GetPopulation(BaseModel):
             '''Get the current population in a given location'''
 
             location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
 
+
         configurable_model = init_chat_model(
-            "gpt-4o",
-            configurable_fields=("model", "model_provider"),
-            temperature=0
+            "gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
         )
 
-        configurable_model_with_tools = configurable_model.bind_tools([GetWeather, GetPopulation])
+        configurable_model_with_tools = configurable_model.bind_tools(
+            [GetWeather, GetPopulation]
+        )
         configurable_model_with_tools.invoke(
             "Which city is hotter today and which is bigger: LA or NY?"
         )
@@ -265,7 +265,7 @@ def init_chat_model(
 
         configurable_model_with_tools.invoke(
             "Which city is hotter today and which is bigger: LA or NY?",
-            config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
+            config={"configurable": {"model": "claude-3-5-sonnet-latest"}},
         )
         # Claude-3.5 sonnet response with tools
 
@@ -530,12 +530,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         self,
         *,
         default_config: dict | None = None,
-        configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
+        configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
         config_prefix: str = "",
         queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
     ) -> None:
         self._default_config: dict = default_config or {}
-        self._configurable_fields: Union[Literal["any"], list[str]] = (
+        self._configurable_fields: Literal["any"] | list[str] = (
             configurable_fields if configurable_fields == "any" else list(configurable_fields)
         )
         self._config_prefix = (
@@ -638,11 +638,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         # This is a version of LanguageModelInput which replaces the abstract
         # base class BaseMessage with a union of its subclasses, which makes
         # for a much better schema.
-        return Union[
-            str,
-            Union[StringPromptValue, ChatPromptValueConcrete],
-            list[AnyMessage],
-        ]
+        return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
 
     @override
     def invoke(
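Returning a bare pipe union works because a `|` union is an ordinary runtime object (`types.UnionType`), and nesting flattens just as it does for `typing.Union`, so the collapsed one-liner above describes the same set of types as the deleted nested form. A standalone illustration:

```python
from typing import Union, get_args

# PEP 604 unions are regular runtime values that can be returned
# from a method, and nested unions flatten automatically.
new_style = str | list[int]
old_style = Union[str, Union[str, list[int]]]  # inner Union flattens away

assert get_args(new_style) == (str, list[int])
assert get_args(old_style) == (str, list[int])
```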
@@ -684,7 +680,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch(
         self,
         inputs: list[LanguageModelInput],
-        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
+        config: RunnableConfig | list[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any | None,
@@ -712,7 +708,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch(
         self,
         inputs: list[LanguageModelInput],
-        config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
+        config: RunnableConfig | list[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any | None,
@@ -740,11 +736,11 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def batch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
+        config: RunnableConfig | Sequence[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
-    ) -> Iterator[tuple[int, Union[Any, Exception]]]:
+    ) -> Iterator[tuple[int, Any | Exception]]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
         if config is None or isinstance(config, dict) or len(config) <= 1:
@@ -769,7 +765,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     async def abatch_as_completed(
         self,
         inputs: Sequence[LanguageModelInput],
-        config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
+        config: RunnableConfig | Sequence[RunnableConfig] | None = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -867,7 +863,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         exclude_types: Sequence[str] | None = None,
         exclude_tags: Sequence[str] | None = None,
         **kwargs: Any,
-    ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
+    ) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
         async for x in self._model(config).astream_log(  # type: ignore[call-overload, misc]
             input,
             config=config,
@@ -915,7 +911,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     # Explicitly added to satisfy downstream linters.
     def bind_tools(
         self,
-        tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
+        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, AIMessage]:
         return self.__getattr__("bind_tools")(tools, **kwargs)
@@ -923,7 +919,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     # Explicitly added to satisfy downstream linters.
     def with_structured_output(
         self,
-        schema: Union[dict, type[BaseModel]],
+        schema: dict | type[BaseModel],
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
+    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
         return self.__getattr__("with_structured_output")(schema, **kwargs)
@@ -2,7 +2,7 @@
 
 import functools
 from importlib import util
-from typing import Any, Union
+from typing import Any
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables import Runnable
@@ -126,7 +126,7 @@ def init_embeddings(
     *,
     provider: str | None = None,
     **kwargs: Any,
-) -> Union[Embeddings, Runnable[Any, list[float]]]:
+) -> Embeddings | Runnable[Any, list[float]]:
     """Initialize an embeddings model from a model name and optional provider.
 
     **Note:** Must have the integration package corresponding to the model provider
@@ -162,17 +162,11 @@ def init_embeddings(
         model.embed_query("Hello, world!")
 
         # Using explicit provider
-        model = init_embeddings(
-            model="text-embedding-3-small",
-            provider="openai"
-        )
+        model = init_embeddings(model="text-embedding-3-small", provider="openai")
         model.embed_documents(["Hello, world!", "Goodbye, world!"])
 
         # With additional parameters
-        model = init_embeddings(
-            "openai:text-embedding-3-small",
-            api_key="sk-..."
-        )
+        model = init_embeddings("openai:text-embedding-3-small", api_key="sk-...")
 
     .. versionadded:: 0.3.9
 
@@ -13,7 +13,7 @@ import hashlib
 import json
 import uuid
 import warnings
-from typing import TYPE_CHECKING, Literal, Union, cast
+from typing import TYPE_CHECKING, Literal, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils.iter import batch_iterate
@@ -178,7 +178,7 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
+        vectors: list[list[float] | None] = self.document_embedding_store.mget(
             texts,
         )
         all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
@@ -210,7 +210,7 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
+        vectors: list[list[float] | None] = await self.document_embedding_store.amget(texts)
         all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
 
         # batch_iterate supports None batch_size which returns all elements at once
@@ -285,11 +285,8 @@ class CacheBackedEmbeddings(Embeddings):
         *,
         namespace: str = "",
         batch_size: int | None = None,
-        query_embedding_cache: Union[bool, ByteStore] = False,
-        key_encoder: Union[
-            Callable[[str], str],
-            Literal["sha1", "blake2b", "sha256", "sha512"],
-        ] = "sha1",
+        query_embedding_cache: bool | ByteStore = False,
+        key_encoder: Callable[[str], str] | Literal["sha1", "blake2b", "sha256", "sha512"] = "sha1",
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.