langchain-core 1.0.0a6__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. langchain_core/__init__.py +1 -1
  2. langchain_core/_api/__init__.py +3 -4
  3. langchain_core/_api/beta_decorator.py +23 -26
  4. langchain_core/_api/deprecation.py +51 -64
  5. langchain_core/_api/path.py +3 -6
  6. langchain_core/_import_utils.py +3 -4
  7. langchain_core/agents.py +55 -48
  8. langchain_core/caches.py +65 -66
  9. langchain_core/callbacks/__init__.py +1 -8
  10. langchain_core/callbacks/base.py +321 -336
  11. langchain_core/callbacks/file.py +44 -44
  12. langchain_core/callbacks/manager.py +454 -514
  13. langchain_core/callbacks/stdout.py +29 -30
  14. langchain_core/callbacks/streaming_stdout.py +32 -32
  15. langchain_core/callbacks/usage.py +60 -57
  16. langchain_core/chat_history.py +53 -68
  17. langchain_core/document_loaders/base.py +27 -25
  18. langchain_core/document_loaders/blob_loaders.py +1 -1
  19. langchain_core/document_loaders/langsmith.py +44 -48
  20. langchain_core/documents/__init__.py +23 -3
  21. langchain_core/documents/base.py +102 -94
  22. langchain_core/documents/compressor.py +10 -10
  23. langchain_core/documents/transformers.py +34 -35
  24. langchain_core/embeddings/fake.py +50 -54
  25. langchain_core/example_selectors/length_based.py +2 -2
  26. langchain_core/example_selectors/semantic_similarity.py +28 -32
  27. langchain_core/exceptions.py +21 -20
  28. langchain_core/globals.py +3 -151
  29. langchain_core/indexing/__init__.py +1 -1
  30. langchain_core/indexing/api.py +121 -126
  31. langchain_core/indexing/base.py +73 -75
  32. langchain_core/indexing/in_memory.py +4 -6
  33. langchain_core/language_models/__init__.py +14 -29
  34. langchain_core/language_models/_utils.py +58 -61
  35. langchain_core/language_models/base.py +82 -172
  36. langchain_core/language_models/chat_models.py +329 -402
  37. langchain_core/language_models/fake.py +11 -11
  38. langchain_core/language_models/fake_chat_models.py +42 -36
  39. langchain_core/language_models/llms.py +189 -269
  40. langchain_core/load/dump.py +9 -12
  41. langchain_core/load/load.py +18 -28
  42. langchain_core/load/mapping.py +2 -4
  43. langchain_core/load/serializable.py +42 -40
  44. langchain_core/messages/__init__.py +10 -16
  45. langchain_core/messages/ai.py +148 -148
  46. langchain_core/messages/base.py +53 -51
  47. langchain_core/messages/block_translators/__init__.py +19 -22
  48. langchain_core/messages/block_translators/anthropic.py +6 -6
  49. langchain_core/messages/block_translators/bedrock_converse.py +5 -5
  50. langchain_core/messages/block_translators/google_genai.py +10 -7
  51. langchain_core/messages/block_translators/google_vertexai.py +4 -32
  52. langchain_core/messages/block_translators/groq.py +117 -21
  53. langchain_core/messages/block_translators/langchain_v0.py +5 -5
  54. langchain_core/messages/block_translators/openai.py +11 -11
  55. langchain_core/messages/chat.py +2 -6
  56. langchain_core/messages/content.py +339 -330
  57. langchain_core/messages/function.py +6 -10
  58. langchain_core/messages/human.py +24 -31
  59. langchain_core/messages/modifier.py +2 -2
  60. langchain_core/messages/system.py +19 -29
  61. langchain_core/messages/tool.py +74 -90
  62. langchain_core/messages/utils.py +484 -510
  63. langchain_core/output_parsers/__init__.py +13 -10
  64. langchain_core/output_parsers/base.py +61 -61
  65. langchain_core/output_parsers/format_instructions.py +9 -4
  66. langchain_core/output_parsers/json.py +12 -10
  67. langchain_core/output_parsers/list.py +21 -23
  68. langchain_core/output_parsers/openai_functions.py +49 -47
  69. langchain_core/output_parsers/openai_tools.py +30 -23
  70. langchain_core/output_parsers/pydantic.py +13 -14
  71. langchain_core/output_parsers/string.py +5 -5
  72. langchain_core/output_parsers/transform.py +15 -17
  73. langchain_core/output_parsers/xml.py +35 -34
  74. langchain_core/outputs/__init__.py +1 -1
  75. langchain_core/outputs/chat_generation.py +18 -18
  76. langchain_core/outputs/chat_result.py +1 -3
  77. langchain_core/outputs/generation.py +16 -16
  78. langchain_core/outputs/llm_result.py +10 -10
  79. langchain_core/prompt_values.py +13 -19
  80. langchain_core/prompts/__init__.py +3 -27
  81. langchain_core/prompts/base.py +81 -86
  82. langchain_core/prompts/chat.py +308 -351
  83. langchain_core/prompts/dict.py +6 -6
  84. langchain_core/prompts/few_shot.py +81 -88
  85. langchain_core/prompts/few_shot_with_templates.py +11 -13
  86. langchain_core/prompts/image.py +12 -14
  87. langchain_core/prompts/loading.py +4 -6
  88. langchain_core/prompts/message.py +7 -7
  89. langchain_core/prompts/prompt.py +24 -39
  90. langchain_core/prompts/string.py +26 -10
  91. langchain_core/prompts/structured.py +49 -53
  92. langchain_core/rate_limiters.py +51 -60
  93. langchain_core/retrievers.py +61 -198
  94. langchain_core/runnables/base.py +1551 -1656
  95. langchain_core/runnables/branch.py +68 -70
  96. langchain_core/runnables/config.py +72 -89
  97. langchain_core/runnables/configurable.py +145 -161
  98. langchain_core/runnables/fallbacks.py +102 -96
  99. langchain_core/runnables/graph.py +91 -97
  100. langchain_core/runnables/graph_ascii.py +27 -28
  101. langchain_core/runnables/graph_mermaid.py +42 -51
  102. langchain_core/runnables/graph_png.py +43 -16
  103. langchain_core/runnables/history.py +175 -177
  104. langchain_core/runnables/passthrough.py +151 -167
  105. langchain_core/runnables/retry.py +46 -51
  106. langchain_core/runnables/router.py +30 -35
  107. langchain_core/runnables/schema.py +75 -80
  108. langchain_core/runnables/utils.py +60 -67
  109. langchain_core/stores.py +85 -121
  110. langchain_core/structured_query.py +8 -8
  111. langchain_core/sys_info.py +29 -29
  112. langchain_core/tools/__init__.py +1 -14
  113. langchain_core/tools/base.py +306 -245
  114. langchain_core/tools/convert.py +160 -155
  115. langchain_core/tools/render.py +10 -10
  116. langchain_core/tools/retriever.py +12 -11
  117. langchain_core/tools/simple.py +19 -24
  118. langchain_core/tools/structured.py +32 -39
  119. langchain_core/tracers/__init__.py +1 -9
  120. langchain_core/tracers/base.py +97 -99
  121. langchain_core/tracers/context.py +29 -52
  122. langchain_core/tracers/core.py +49 -53
  123. langchain_core/tracers/evaluation.py +11 -11
  124. langchain_core/tracers/event_stream.py +65 -64
  125. langchain_core/tracers/langchain.py +21 -21
  126. langchain_core/tracers/log_stream.py +45 -45
  127. langchain_core/tracers/memory_stream.py +3 -3
  128. langchain_core/tracers/root_listeners.py +16 -16
  129. langchain_core/tracers/run_collector.py +2 -4
  130. langchain_core/tracers/schemas.py +0 -129
  131. langchain_core/tracers/stdout.py +3 -3
  132. langchain_core/utils/__init__.py +1 -4
  133. langchain_core/utils/_merge.py +2 -2
  134. langchain_core/utils/aiter.py +57 -61
  135. langchain_core/utils/env.py +9 -9
  136. langchain_core/utils/function_calling.py +94 -188
  137. langchain_core/utils/html.py +7 -8
  138. langchain_core/utils/input.py +9 -6
  139. langchain_core/utils/interactive_env.py +1 -1
  140. langchain_core/utils/iter.py +36 -40
  141. langchain_core/utils/json.py +4 -3
  142. langchain_core/utils/json_schema.py +9 -9
  143. langchain_core/utils/mustache.py +8 -10
  144. langchain_core/utils/pydantic.py +35 -37
  145. langchain_core/utils/strings.py +6 -9
  146. langchain_core/utils/usage.py +1 -1
  147. langchain_core/utils/utils.py +66 -62
  148. langchain_core/vectorstores/base.py +182 -216
  149. langchain_core/vectorstores/in_memory.py +101 -176
  150. langchain_core/vectorstores/utils.py +5 -5
  151. langchain_core/version.py +1 -1
  152. langchain_core-1.0.4.dist-info/METADATA +69 -0
  153. langchain_core-1.0.4.dist-info/RECORD +172 -0
  154. {langchain_core-1.0.0a6.dist-info → langchain_core-1.0.4.dist-info}/WHEEL +1 -1
  155. langchain_core/memory.py +0 -120
  156. langchain_core/messages/block_translators/ollama.py +0 -47
  157. langchain_core/prompts/pipeline.py +0 -138
  158. langchain_core/pydantic_v1/__init__.py +0 -30
  159. langchain_core/pydantic_v1/dataclasses.py +0 -23
  160. langchain_core/pydantic_v1/main.py +0 -23
  161. langchain_core/tracers/langchain_v1.py +0 -31
  162. langchain_core/utils/loading.py +0 -35
  163. langchain_core-1.0.0a6.dist-info/METADATA +0 -67
  164. langchain_core-1.0.0a6.dist-info/RECORD +0 -181
  165. langchain_core-1.0.0a6.dist-info/entry_points.txt +0 -4
@@ -1,11 +1,15 @@
1
1
  """Runnable that selects which branch to run based on a condition."""
2
2
 
3
- from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
3
+ from collections.abc import (
4
+ AsyncIterator,
5
+ Awaitable,
6
+ Callable,
7
+ Iterator,
8
+ Mapping,
9
+ Sequence,
10
+ )
4
11
  from typing import (
5
12
  Any,
6
- Callable,
7
- Optional,
8
- Union,
9
13
  cast,
10
14
  )
11
15
 
@@ -32,68 +36,64 @@ from langchain_core.runnables.utils import (
32
36
  get_unique_config_specs,
33
37
  )
34
38
 
39
+ _MIN_BRANCHES = 2
40
+
35
41
 
36
42
  class RunnableBranch(RunnableSerializable[Input, Output]):
37
- """Runnable that selects which branch to run based on a condition.
43
+ """`Runnable` that selects which branch to run based on a condition.
38
44
 
39
- The Runnable is initialized with a list of (condition, Runnable) pairs and
45
+ The `Runnable` is initialized with a list of `(condition, Runnable)` pairs and
40
46
  a default branch.
41
47
 
42
48
  When operating on an input, the first condition that evaluates to True is
43
- selected, and the corresponding Runnable is run on the input.
49
+ selected, and the corresponding `Runnable` is run on the input.
44
50
 
45
- If no condition evaluates to True, the default branch is run on the input.
51
+ If no condition evaluates to `True`, the default branch is run on the input.
46
52
 
47
53
  Examples:
54
+ ```python
55
+ from langchain_core.runnables import RunnableBranch
56
+
57
+ branch = RunnableBranch(
58
+ (lambda x: isinstance(x, str), lambda x: x.upper()),
59
+ (lambda x: isinstance(x, int), lambda x: x + 1),
60
+ (lambda x: isinstance(x, float), lambda x: x * 2),
61
+ lambda x: "goodbye",
62
+ )
48
63
 
49
- .. code-block:: python
50
-
51
- from langchain_core.runnables import RunnableBranch
52
-
53
- branch = RunnableBranch(
54
- (lambda x: isinstance(x, str), lambda x: x.upper()),
55
- (lambda x: isinstance(x, int), lambda x: x + 1),
56
- (lambda x: isinstance(x, float), lambda x: x * 2),
57
- lambda x: "goodbye",
58
- )
59
-
60
- branch.invoke("hello") # "HELLO"
61
- branch.invoke(None) # "goodbye"
62
-
64
+ branch.invoke("hello") # "HELLO"
65
+ branch.invoke(None) # "goodbye"
66
+ ```
63
67
  """
64
68
 
65
69
  branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]]
66
- """A list of (condition, Runnable) pairs."""
70
+ """A list of `(condition, Runnable)` pairs."""
67
71
  default: Runnable[Input, Output]
68
- """A Runnable to run if no condition is met."""
72
+ """A `Runnable` to run if no condition is met."""
69
73
 
70
74
  def __init__(
71
75
  self,
72
- *branches: Union[
73
- tuple[
74
- Union[
75
- Runnable[Input, bool],
76
- Callable[[Input], bool],
77
- Callable[[Input], Awaitable[bool]],
78
- ],
79
- RunnableLike,
80
- ],
81
- RunnableLike, # To accommodate the default branch
82
- ],
76
+ *branches: tuple[
77
+ Runnable[Input, bool]
78
+ | Callable[[Input], bool]
79
+ | Callable[[Input], Awaitable[bool]],
80
+ RunnableLike,
81
+ ]
82
+ | RunnableLike,
83
83
  ) -> None:
84
- """A Runnable that runs one of two branches based on a condition.
84
+ """A `Runnable` that runs one of two branches based on a condition.
85
85
 
86
86
  Args:
87
- *branches: A list of (condition, Runnable) pairs.
88
- Defaults a Runnable to run if no condition is met.
87
+ *branches: A list of `(condition, Runnable)` pairs.
88
+ The last item is the default `Runnable` to run if no condition is met.
89
89
 
90
90
  Raises:
91
- ValueError: If the number of branches is less than 2.
92
- TypeError: If the default branch is not Runnable, Callable or Mapping.
93
- TypeError: If a branch is not a tuple or list.
94
- ValueError: If a branch is not of length 2.
91
+ ValueError: If the number of branches is less than `2`.
92
+ TypeError: If the default branch is not `Runnable`, `Callable` or `Mapping`.
93
+ TypeError: If a branch is not a `tuple` or `list`.
94
+ ValueError: If a branch is not of length `2`.
95
95
  """
96
- if len(branches) < 2:
96
+ if len(branches) < _MIN_BRANCHES:
97
97
  msg = "RunnableBranch requires at least two branches"
98
98
  raise ValueError(msg)
99
99
 
@@ -120,7 +120,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
120
120
  )
121
121
  raise TypeError(msg)
122
122
 
123
- if len(branch) != 2:
123
+ if len(branch) != _MIN_BRANCHES:
124
124
  msg = (
125
125
  f"RunnableBranch branches must be "
126
126
  f"tuples or lists of length 2, not {len(branch)}"
@@ -142,23 +142,21 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
142
142
 
143
143
  @classmethod
144
144
  def is_lc_serializable(cls) -> bool:
145
- """Return True as this class is serializable."""
145
+ """Return `True` as this class is serializable."""
146
146
  return True
147
147
 
148
148
  @classmethod
149
149
  @override
150
150
  def get_lc_namespace(cls) -> list[str]:
151
- """Get the namespace of the langchain object.
151
+ """Get the namespace of the LangChain object.
152
152
 
153
153
  Returns:
154
- ``["langchain", "schema", "runnable"]``
154
+ `["langchain", "schema", "runnable"]`
155
155
  """
156
156
  return ["langchain", "schema", "runnable"]
157
157
 
158
158
  @override
159
- def get_input_schema(
160
- self, config: Optional[RunnableConfig] = None
161
- ) -> type[BaseModel]:
159
+ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
162
160
  runnables = (
163
161
  [self.default]
164
162
  + [r for _, r in self.branches]
@@ -189,14 +187,14 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
189
187
 
190
188
  @override
191
189
  def invoke(
192
- self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
190
+ self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
193
191
  ) -> Output:
194
- """First evaluates the condition, then delegate to true or false branch.
192
+ """First evaluates the condition, then delegate to `True` or `False` branch.
195
193
 
196
194
  Args:
197
- input: The input to the Runnable.
198
- config: The configuration for the Runnable. Defaults to None.
199
- kwargs: Additional keyword arguments to pass to the Runnable.
195
+ input: The input to the `Runnable`.
196
+ config: The configuration for the `Runnable`.
197
+ **kwargs: Additional keyword arguments to pass to the `Runnable`.
200
198
 
201
199
  Returns:
202
200
  The output of the branch that was run.
@@ -248,7 +246,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
248
246
 
249
247
  @override
250
248
  async def ainvoke(
251
- self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
249
+ self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
252
250
  ) -> Output:
253
251
  config = ensure_config(config)
254
252
  callback_manager = get_async_callback_manager_for_config(config)
@@ -298,15 +296,15 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
298
296
  def stream(
299
297
  self,
300
298
  input: Input,
301
- config: Optional[RunnableConfig] = None,
302
- **kwargs: Optional[Any],
299
+ config: RunnableConfig | None = None,
300
+ **kwargs: Any | None,
303
301
  ) -> Iterator[Output]:
304
- """First evaluates the condition, then delegate to true or false branch.
302
+ """First evaluates the condition, then delegate to `True` or `False` branch.
305
303
 
306
304
  Args:
307
- input: The input to the Runnable.
308
- config: The configuration for the Runnable. Defaults to None.
309
- kwargs: Additional keyword arguments to pass to the Runnable.
305
+ input: The input to the `Runnable`.
306
+ config: The configuration for the `Runnable`.
307
+ **kwargs: Additional keyword arguments to pass to the `Runnable`.
310
308
 
311
309
  Yields:
312
310
  The output of the branch that was run.
@@ -319,7 +317,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
319
317
  name=config.get("run_name") or self.get_name(),
320
318
  run_id=config.pop("run_id", None),
321
319
  )
322
- final_output: Optional[Output] = None
320
+ final_output: Output | None = None
323
321
  final_output_supported = True
324
322
 
325
323
  try:
@@ -382,15 +380,15 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
382
380
  async def astream(
383
381
  self,
384
382
  input: Input,
385
- config: Optional[RunnableConfig] = None,
386
- **kwargs: Optional[Any],
383
+ config: RunnableConfig | None = None,
384
+ **kwargs: Any | None,
387
385
  ) -> AsyncIterator[Output]:
388
- """First evaluates the condition, then delegate to true or false branch.
386
+ """First evaluates the condition, then delegate to `True` or `False` branch.
389
387
 
390
388
  Args:
391
- input: The input to the Runnable.
392
- config: The configuration for the Runnable. Defaults to None.
393
- kwargs: Additional keyword arguments to pass to the Runnable.
389
+ input: The input to the `Runnable`.
390
+ config: The configuration for the `Runnable`.
391
+ **kwargs: Additional keyword arguments to pass to the `Runnable`.
394
392
 
395
393
  Yields:
396
394
  The output of the branch that was run.
@@ -403,7 +401,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
403
401
  name=config.get("run_name") or self.get_name(),
404
402
  run_id=config.pop("run_id", None),
405
403
  )
406
- final_output: Optional[Output] = None
404
+ final_output: Output | None = None
407
405
  final_output_supported = True
408
406
 
409
407
  try:
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import asyncio
6
6
  import uuid
7
7
  import warnings
8
- from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
8
+ from collections.abc import Awaitable, Callable, Generator, Iterable, Iterator, Sequence
9
9
  from concurrent.futures import Executor, Future, ThreadPoolExecutor
10
10
  from contextlib import contextmanager
11
11
  from contextvars import Context, ContextVar, Token, copy_context
@@ -13,11 +13,8 @@ from functools import partial
13
13
  from typing import (
14
14
  TYPE_CHECKING,
15
15
  Any,
16
- Callable,
17
- Optional,
18
16
  ParamSpec,
19
17
  TypeVar,
20
- Union,
21
18
  cast,
22
19
  )
23
20
 
@@ -42,7 +39,7 @@ if TYPE_CHECKING:
42
39
  else:
43
40
  # Pydantic validates through typed dicts, but
44
41
  # the callbacks need forward refs updated
45
- Callbacks = Optional[Union[list, Any]]
42
+ Callbacks = list | Any | None
46
43
 
47
44
 
48
45
  class EmptyDict(TypedDict, total=False):
@@ -75,29 +72,29 @@ class RunnableConfig(TypedDict, total=False):
75
72
  Name for the tracer run for this call. Defaults to the name of the class.
76
73
  """
77
74
 
78
- max_concurrency: Optional[int]
75
+ max_concurrency: int | None
79
76
  """
80
77
  Maximum number of parallel calls to make. If not provided, defaults to
81
- ThreadPoolExecutor's default.
78
+ `ThreadPoolExecutor`'s default.
82
79
  """
83
80
 
84
81
  recursion_limit: int
85
82
  """
86
- Maximum number of times a call can recurse. If not provided, defaults to 25.
83
+ Maximum number of times a call can recurse. If not provided, defaults to `25`.
87
84
  """
88
85
 
89
86
  configurable: dict[str, Any]
90
87
  """
91
- Runtime values for attributes previously made configurable on this Runnable,
92
- or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
93
- Check .output_schema() for a description of the attributes that have been made
88
+ Runtime values for attributes previously made configurable on this `Runnable`,
89
+ or sub-Runnables, through `configurable_fields` or `configurable_alternatives`.
90
+ Check `output_schema` for a description of the attributes that have been made
94
91
  configurable.
95
92
  """
96
93
 
97
- run_id: Optional[uuid.UUID]
94
+ run_id: uuid.UUID | None
98
95
  """
99
96
  Unique identifier for the tracer run for this call. If not provided, a new UUID
100
- will be generated.
97
+ will be generated.
101
98
  """
102
99
 
103
100
 
@@ -130,11 +127,11 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar(
130
127
  # This is imported and used in langgraph, so don't break.
131
128
  def _set_config_context(
132
129
  config: RunnableConfig,
133
- ) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]:
130
+ ) -> tuple[Token[RunnableConfig | None], dict[str, Any] | None]:
134
131
  """Set the child Runnable config + tracing context.
135
132
 
136
133
  Args:
137
- config (RunnableConfig): The config to set.
134
+ config: The config to set.
138
135
 
139
136
  Returns:
140
137
  The token to reset the config and the previous tracing context.
@@ -168,7 +165,7 @@ def set_config_context(config: RunnableConfig) -> Generator[Context, None, None]
168
165
  """Set the child Runnable config + tracing context.
169
166
 
170
167
  Args:
171
- config (RunnableConfig): The config to set.
168
+ config: The config to set.
172
169
 
173
170
  Yields:
174
171
  The config context.
@@ -192,15 +189,14 @@ def set_config_context(config: RunnableConfig) -> Generator[Context, None, None]
192
189
  )
193
190
 
194
191
 
195
- def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
192
+ def ensure_config(config: RunnableConfig | None = None) -> RunnableConfig:
196
193
  """Ensure that a config is a dict with all keys present.
197
194
 
198
195
  Args:
199
- config (Optional[RunnableConfig], optional): The config to ensure.
200
- Defaults to None.
196
+ config: The config to ensure.
201
197
 
202
198
  Returns:
203
- RunnableConfig: The ensured config.
199
+ The ensured config.
204
200
  """
205
201
  empty = RunnableConfig(
206
202
  tags=[],
@@ -247,19 +243,18 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
247
243
 
248
244
 
249
245
  def get_config_list(
250
- config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int
246
+ config: RunnableConfig | Sequence[RunnableConfig] | None, length: int
251
247
  ) -> list[RunnableConfig]:
252
248
  """Get a list of configs from a single config or a list of configs.
253
249
 
254
250
  It is useful for subclasses overriding batch() or abatch().
255
251
 
256
252
  Args:
257
- config (Optional[Union[RunnableConfig, list[RunnableConfig]]]):
258
- The config or list of configs.
259
- length (int): The length of the list.
253
+ config: The config or list of configs.
254
+ length: The length of the list.
260
255
 
261
256
  Returns:
262
- list[RunnableConfig]: The list of configs.
257
+ The list of configs.
263
258
 
264
259
  Raises:
265
260
  ValueError: If the length of the list is not equal to the length of the inputs.
@@ -294,30 +289,26 @@ def get_config_list(
294
289
 
295
290
 
296
291
  def patch_config(
297
- config: Optional[RunnableConfig],
292
+ config: RunnableConfig | None,
298
293
  *,
299
- callbacks: Optional[BaseCallbackManager] = None,
300
- recursion_limit: Optional[int] = None,
301
- max_concurrency: Optional[int] = None,
302
- run_name: Optional[str] = None,
303
- configurable: Optional[dict[str, Any]] = None,
294
+ callbacks: BaseCallbackManager | None = None,
295
+ recursion_limit: int | None = None,
296
+ max_concurrency: int | None = None,
297
+ run_name: str | None = None,
298
+ configurable: dict[str, Any] | None = None,
304
299
  ) -> RunnableConfig:
305
300
  """Patch a config with new values.
306
301
 
307
302
  Args:
308
- config (Optional[RunnableConfig]): The config to patch.
309
- callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
310
- Defaults to None.
311
- recursion_limit (Optional[int], optional): The recursion limit to set.
312
- Defaults to None.
313
- max_concurrency (Optional[int], optional): The max concurrency to set.
314
- Defaults to None.
315
- run_name (Optional[str], optional): The run name to set. Defaults to None.
316
- configurable (Optional[dict[str, Any]], optional): The configurable to set.
317
- Defaults to None.
303
+ config: The config to patch.
304
+ callbacks: The callbacks to set.
305
+ recursion_limit: The recursion limit to set.
306
+ max_concurrency: The max concurrency to set.
307
+ run_name: The run name to set.
308
+ configurable: The configurable to set.
318
309
 
319
310
  Returns:
320
- RunnableConfig: The patched config.
311
+ The patched config.
321
312
  """
322
313
  config = ensure_config(config)
323
314
  if callbacks is not None:
@@ -339,14 +330,14 @@ def patch_config(
339
330
  return config
340
331
 
341
332
 
342
- def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
333
+ def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
343
334
  """Merge multiple configs into one.
344
335
 
345
336
  Args:
346
- *configs (Optional[RunnableConfig]): The configs to merge.
337
+ *configs: The configs to merge.
347
338
 
348
339
  Returns:
349
- RunnableConfig: The merged config.
340
+ The merged config.
350
341
  """
351
342
  base: RunnableConfig = {}
352
343
  # Even though the keys aren't literals, this is correct
@@ -406,15 +397,13 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
406
397
 
407
398
 
408
399
  def call_func_with_variable_args(
409
- func: Union[
410
- Callable[[Input], Output],
411
- Callable[[Input, RunnableConfig], Output],
412
- Callable[[Input, CallbackManagerForChainRun], Output],
413
- Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
414
- ],
400
+ func: Callable[[Input], Output]
401
+ | Callable[[Input, RunnableConfig], Output]
402
+ | Callable[[Input, CallbackManagerForChainRun], Output]
403
+ | Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
415
404
  input: Input,
416
405
  config: RunnableConfig,
417
- run_manager: Optional[CallbackManagerForChainRun] = None,
406
+ run_manager: CallbackManagerForChainRun | None = None,
418
407
  **kwargs: Any,
419
408
  ) -> Output:
420
409
  """Call function that may optionally accept a run_manager and/or config.
@@ -423,7 +412,7 @@ def call_func_with_variable_args(
423
412
  func: The function to call.
424
413
  input: The input to the function.
425
414
  config: The config to pass to the function.
426
- run_manager: The run manager to pass to the function. Defaults to None.
415
+ run_manager: The run manager to pass to the function.
427
416
  **kwargs: The keyword arguments to pass to the function.
428
417
 
429
418
  Returns:
@@ -440,18 +429,15 @@ def call_func_with_variable_args(
440
429
 
441
430
 
442
431
  def acall_func_with_variable_args(
443
- func: Union[
444
- Callable[[Input], Awaitable[Output]],
445
- Callable[[Input, RunnableConfig], Awaitable[Output]],
446
- Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
447
- Callable[
448
- [Input, AsyncCallbackManagerForChainRun, RunnableConfig],
449
- Awaitable[Output],
450
- ],
432
+ func: Callable[[Input], Awaitable[Output]]
433
+ | Callable[[Input, RunnableConfig], Awaitable[Output]]
434
+ | Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]]
435
+ | Callable[
436
+ [Input, AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]
451
437
  ],
452
438
  input: Input,
453
439
  config: RunnableConfig,
454
- run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
440
+ run_manager: AsyncCallbackManagerForChainRun | None = None,
455
441
  **kwargs: Any,
456
442
  ) -> Awaitable[Output]:
457
443
  """Async call function that may optionally accept a run_manager and/or config.
@@ -460,7 +446,7 @@ def acall_func_with_variable_args(
460
446
  func: The function to call.
461
447
  input: The input to the function.
462
448
  config: The config to pass to the function.
463
- run_manager: The run manager to pass to the function. Defaults to None.
449
+ run_manager: The run manager to pass to the function.
464
450
  **kwargs: The keyword arguments to pass to the function.
465
451
 
466
452
  Returns:
@@ -480,10 +466,10 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
480
466
  """Get a callback manager for a config.
481
467
 
482
468
  Args:
483
- config (RunnableConfig): The config.
469
+ config: The config.
484
470
 
485
471
  Returns:
486
- CallbackManager: The callback manager.
472
+ The callback manager.
487
473
  """
488
474
  return CallbackManager.configure(
489
475
  inheritable_callbacks=config.get("callbacks"),
@@ -498,10 +484,10 @@ def get_async_callback_manager_for_config(
498
484
  """Get an async callback manager for a config.
499
485
 
500
486
  Args:
501
- config (RunnableConfig): The config.
487
+ config: The config.
502
488
 
503
489
  Returns:
504
- AsyncCallbackManager: The async callback manager.
490
+ The async callback manager.
505
491
  """
506
492
  return AsyncCallbackManager.configure(
507
493
  inheritable_callbacks=config.get("callbacks"),
@@ -526,12 +512,12 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
526
512
  """Submit a function to the executor.
527
513
 
528
514
  Args:
529
- func (Callable[..., T]): The function to submit.
530
- *args (Any): The positional arguments to the function.
531
- **kwargs (Any): The keyword arguments to the function.
515
+ func: The function to submit.
516
+ *args: The positional arguments to the function.
517
+ **kwargs: The keyword arguments to the function.
532
518
 
533
519
  Returns:
534
- Future[T]: The future for the function.
520
+ The future for the function.
535
521
  """
536
522
  return super().submit(
537
523
  cast("Callable[..., T]", partial(copy_context().run, func, *args, **kwargs))
@@ -541,20 +527,18 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
541
527
  self,
542
528
  fn: Callable[..., T],
543
529
  *iterables: Iterable[Any],
544
- timeout: float | None = None,
545
- chunksize: int = 1,
530
+ **kwargs: Any,
546
531
  ) -> Iterator[T]:
547
532
  """Map a function to multiple iterables.
548
533
 
549
534
  Args:
550
- fn (Callable[..., T]): The function to map.
551
- *iterables (Iterable[Any]): The iterables to map over.
552
- timeout (float | None, optional): The timeout for the map.
553
- Defaults to None.
554
- chunksize (int, optional): The chunksize for the map. Defaults to 1.
535
+ fn: The function to map.
536
+ *iterables: The iterables to map over.
537
+ timeout: The timeout for the map.
538
+ chunksize: The chunksize for the map.
555
539
 
556
540
  Returns:
557
- Iterator[T]: The iterator for the mapped function.
541
+ The iterator for the mapped function.
558
542
  """
559
543
  contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
560
544
 
@@ -564,22 +548,21 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
564
548
  return super().map(
565
549
  _wrapped_fn,
566
550
  *iterables,
567
- timeout=timeout,
568
- chunksize=chunksize,
551
+ **kwargs,
569
552
  )
570
553
 
571
554
 
572
555
  @contextmanager
573
556
  def get_executor_for_config(
574
- config: Optional[RunnableConfig],
557
+ config: RunnableConfig | None,
575
558
  ) -> Generator[Executor, None, None]:
576
559
  """Get an executor for a config.
577
560
 
578
561
  Args:
579
- config (RunnableConfig): The config.
562
+ config: The config.
580
563
 
581
564
  Yields:
582
- Generator[Executor, None, None]: The executor.
565
+ The executor.
583
566
  """
584
567
  config = config or {}
585
568
  with ContextThreadPoolExecutor(
@@ -589,7 +572,7 @@ def get_executor_for_config(
589
572
 
590
573
 
591
574
  async def run_in_executor(
592
- executor_or_config: Optional[Union[Executor, RunnableConfig]],
575
+ executor_or_config: Executor | RunnableConfig | None,
593
576
  func: Callable[P, T],
594
577
  *args: P.args,
595
578
  **kwargs: P.kwargs,
@@ -598,12 +581,12 @@ async def run_in_executor(
598
581
 
599
582
  Args:
600
583
  executor_or_config: The executor or config to run in.
601
- func (Callable[P, Output]): The function.
602
- *args (Any): The positional arguments to the function.
603
- **kwargs (Any): The keyword arguments to the function.
584
+ func: The function.
585
+ *args: The positional arguments to the function.
586
+ **kwargs: The keyword arguments to the function.
604
587
 
605
588
  Returns:
606
- Output: The output of the function.
589
+ The output of the function.
607
590
  """
608
591
 
609
592
  def wrapper() -> T: