langchain-dev-utils 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. langchain_dev_utils/__init__.py +1 -0
  2. langchain_dev_utils/_utils.py +131 -0
  3. langchain_dev_utils/agents/__init__.py +4 -0
  4. langchain_dev_utils/agents/factory.py +99 -0
  5. langchain_dev_utils/agents/file_system.py +252 -0
  6. langchain_dev_utils/agents/middleware/__init__.py +21 -0
  7. langchain_dev_utils/agents/middleware/format_prompt.py +66 -0
  8. langchain_dev_utils/agents/middleware/handoffs.py +214 -0
  9. langchain_dev_utils/agents/middleware/model_fallback.py +49 -0
  10. langchain_dev_utils/agents/middleware/model_router.py +200 -0
  11. langchain_dev_utils/agents/middleware/plan.py +367 -0
  12. langchain_dev_utils/agents/middleware/summarization.py +85 -0
  13. langchain_dev_utils/agents/middleware/tool_call_repair.py +96 -0
  14. langchain_dev_utils/agents/middleware/tool_emulator.py +60 -0
  15. langchain_dev_utils/agents/middleware/tool_selection.py +82 -0
  16. langchain_dev_utils/agents/plan.py +188 -0
  17. langchain_dev_utils/agents/wrap.py +324 -0
  18. langchain_dev_utils/chat_models/__init__.py +11 -0
  19. langchain_dev_utils/chat_models/adapters/__init__.py +3 -0
  20. langchain_dev_utils/chat_models/adapters/create_utils.py +53 -0
  21. langchain_dev_utils/chat_models/adapters/openai_compatible.py +715 -0
  22. langchain_dev_utils/chat_models/adapters/register_profiles.py +15 -0
  23. langchain_dev_utils/chat_models/base.py +282 -0
  24. langchain_dev_utils/chat_models/types.py +27 -0
  25. langchain_dev_utils/embeddings/__init__.py +11 -0
  26. langchain_dev_utils/embeddings/adapters/__init__.py +3 -0
  27. langchain_dev_utils/embeddings/adapters/create_utils.py +45 -0
  28. langchain_dev_utils/embeddings/adapters/openai_compatible.py +91 -0
  29. langchain_dev_utils/embeddings/base.py +234 -0
  30. langchain_dev_utils/message_convert/__init__.py +15 -0
  31. langchain_dev_utils/message_convert/content.py +201 -0
  32. langchain_dev_utils/message_convert/format.py +69 -0
  33. langchain_dev_utils/pipeline/__init__.py +7 -0
  34. langchain_dev_utils/pipeline/parallel.py +135 -0
  35. langchain_dev_utils/pipeline/sequential.py +101 -0
  36. langchain_dev_utils/pipeline/types.py +3 -0
  37. langchain_dev_utils/py.typed +0 -0
  38. langchain_dev_utils/tool_calling/__init__.py +14 -0
  39. langchain_dev_utils/tool_calling/human_in_the_loop.py +284 -0
  40. langchain_dev_utils/tool_calling/utils.py +81 -0
  41. langchain_dev_utils-1.3.7.dist-info/METADATA +103 -0
  42. langchain_dev_utils-1.3.7.dist-info/RECORD +44 -0
  43. langchain_dev_utils-1.3.7.dist-info/WHEEL +4 -0
  44. langchain_dev_utils-1.3.7.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,234 @@
1
+ from typing import Any, Literal, NotRequired, Optional, TypedDict, Union
2
+
3
+ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, Embeddings, init_embeddings
4
+ from langchain_core.utils import from_env
5
+
6
+ from langchain_dev_utils._utils import (
7
+ _check_pkg_install,
8
+ _get_base_url_field_name,
9
+ _validate_provider_name,
10
+ )
11
+
12
# Registry of embeddings providers added through register_embeddings_provider,
# keyed by provider name. Each value holds an "embeddings_model" entry and, for
# class-based registrations with a base URL, a "base_url" entry.
_EMBEDDINGS_PROVIDERS_DICT: dict[str, dict[str, Any]] = {}

# What may be registered as an embeddings model: a concrete Embeddings subclass,
# or the literal identifier "openai-compatible".
EmbeddingsType = Union[type[Embeddings], Literal["openai-compatible"]]
15
+
16
+
17
class EmbeddingProvider(TypedDict):
    """One provider entry for ``batch_register_embeddings_provider``.

    Keys:
        provider_name: Name the provider is registered under.
        embeddings_model: An ``Embeddings`` subclass, or the string
            ``"openai-compatible"``.
        base_url: Optional API base URL for the provider (may be omitted).
    """

    provider_name: str
    embeddings_model: EmbeddingsType
    base_url: NotRequired[str]
21
+
22
+
23
+ def _parse_model_string(model_name: str) -> tuple[str, str]:
24
+ """Parse model string into provider and model name.
25
+
26
+ Args:
27
+ model_name: Model name string in format 'provider:model-name'
28
+
29
+ Returns:
30
+ Tuple of (provider, model) parsed from the model_name
31
+
32
+ Raises:
33
+ ValueError: If model name format is invalid or model name is empty
34
+ """
35
+ if ":" not in model_name:
36
+ msg = (
37
+ f"Invalid model format '{model_name}'.\n"
38
+ f"Model name must be in format 'provider:model-name'\n"
39
+ )
40
+ raise ValueError(msg)
41
+
42
+ provider, model = model_name.split(":", 1)
43
+ provider = provider.lower().strip()
44
+ model = model.strip()
45
+ if not model:
46
+ msg = "Model name cannot be empty"
47
+ raise ValueError(msg)
48
+ return provider, model
49
+
50
+
51
def register_embeddings_provider(
    provider_name: str,
    embeddings_model: EmbeddingsType,
    base_url: Optional[str] = None,
):
    """Register an embeddings provider.

    This function allows you to register custom embeddings providers that can be used
    with the load_embeddings function. It supports both custom model classes and
    string identifiers for supported providers. Registering the same
    ``provider_name`` again replaces the previous registration.

    Args:
        provider_name: Name of the provider to register
        embeddings_model: Either an Embeddings class or the string
            ``"openai-compatible"``
        base_url: The API address of the Embedding model provider (optional,
            valid for both types of `embeddings_model`, but mainly used when
            `embeddings_model` is a string and is "openai-compatible"). When
            omitted, the ``{PROVIDER_NAME}_API_BASE`` environment variable is
            consulted.

    Raises:
        ValueError: If ``embeddings_model`` is a string other than
            ``"openai-compatible"``, or if it is a string and no non-empty
            base URL was given either directly or via
            ``{PROVIDER_NAME}_API_BASE``.

    Example:
        # Register with custom model class:
        >>> from langchain_dev_utils.embeddings import register_embeddings_provider, load_embeddings
        >>> from langchain_core.embeddings.fake import FakeEmbeddings
        >>>
        >>> register_embeddings_provider("fakeembeddings", FakeEmbeddings)
        >>> embeddings = load_embeddings("fakeembeddings:fake-embeddings",size=1024)
        >>> embeddings.embed_query("hello world")

        # Register with OpenAI-compatible API:
        >>> register_embeddings_provider(
        ...     "vllm",
        ...     "openai-compatible",
        ...     base_url="http://localhost:8000/v1"
        ... )
        >>> embeddings = load_embeddings("vllm:qwen3-embedding-4b")
        >>> embeddings.embed_query("hello world")
    """
    _validate_provider_name(provider_name)
    env_var = f"{provider_name.upper()}_API_BASE"
    # A truthy explicit base_url wins; otherwise fall back to the environment.
    base_url = base_url or from_env(env_var, default=None)()

    entry: dict[str, Any] = {}
    if isinstance(embeddings_model, str):
        # Validate the identifier first, so a typo is not misreported as a
        # missing base_url.
        if embeddings_model != "openai-compatible":
            raise ValueError(
                "when embeddings_model is a string, the value must be 'openai-compatible'"
            )
        # `not base_url` (rather than `is None`) also rejects an empty
        # environment variable, which would otherwise yield an empty endpoint.
        if not base_url:
            raise ValueError(
                f"base_url must be provided or set {env_var} environment variable when embeddings_model is a string"
            )

        _check_pkg_install("langchain_openai")
        from .adapters.openai_compatible import _create_openai_compatible_embedding

        # base_url is handed to the factory here rather than stored in the
        # registry entry.
        entry["embeddings_model"] = _create_openai_compatible_embedding(
            provider=provider_name,
            base_url=base_url,
        )
    else:
        entry["embeddings_model"] = embeddings_model
        if base_url is not None:
            entry["base_url"] = base_url

    _EMBEDDINGS_PROVIDERS_DICT[provider_name] = entry
132
+
133
+
134
def batch_register_embeddings_provider(
    providers: list[EmbeddingProvider],
):
    """Register several embeddings providers in one call.

    Each entry is forwarded, in list order, to ``register_embeddings_provider``.
    A later entry with the same ``provider_name`` therefore replaces an
    earlier one. If an entry raises, processing stops there and entries
    before it stay registered.

    Args:
        providers: List of EmbeddingProvider dictionaries, each containing:
            - provider_name: str - Provider name
            - embeddings_model: Union[Type[Embeddings], str] - Model class or
              the string "openai-compatible"
            - base_url: The API address of the Embedding model provider
              (optional, valid for both types of `embeddings_model`, but
              mainly used when `embeddings_model` is a string and is
              "openai-compatible")

    Raises:
        ValueError: If any of the providers are invalid
        KeyError: If an entry lacks "provider_name" or "embeddings_model"

    Example:
        # Register multiple providers at once:
        >>> from langchain_dev_utils.embeddings import batch_register_embeddings_provider, load_embeddings
        >>> from langchain_core.embeddings.fake import FakeEmbeddings
        >>>
        >>> batch_register_embeddings_provider(
        ...     [
        ...         {
        ...             "provider_name": "fakeembeddings",
        ...             "embeddings_model": FakeEmbeddings,
        ...         },
        ...         {
        ...             "provider_name": "vllm",
        ...             "embeddings_model": "openai-compatible",
        ...             "base_url": "http://localhost:8000/v1"
        ...         },
        ...     ]
        ... )
        >>> embeddings = load_embeddings("vllm:qwen3-embedding-4b")
        >>> embeddings.embed_query("hello world")
        >>> embeddings = load_embeddings("fakeembeddings:fake-embeddings", size=1024)
        >>> embeddings.embed_query("hello world")
    """
    for entry in providers:
        name = entry["provider_name"]
        model_spec = entry["embeddings_model"]
        url = entry.get("base_url")
        register_embeddings_provider(name, model_spec, url)
185
+
186
+
187
def load_embeddings(
    model: str,
    *,
    provider: Optional[str] = None,
    **kwargs: Any,
) -> Embeddings:
    """Load embeddings model.

    This function loads an embeddings model from the registered providers. The model parameter
    must be specified in the format "provider:model-name" when provider is not specified separately.
    Providers registered via ``register_embeddings_provider`` are looked up
    first; otherwise the provider must be one LangChain's ``init_embeddings``
    supports.

    Args:
        model: Model name in format 'provider:model-name' if provider not specified separately
        provider: Optional provider name (if not included in model parameter)
        **kwargs: Additional arguments for model initialization (e.g., api_key).
            For a registered provider with a stored base_url, a base-URL
            argument passed explicitly here takes precedence over the stored one.

    Returns:
        Embeddings: Initialized embeddings model instance

    Raises:
        ValueError: If ``model`` cannot be parsed (when ``provider`` is not
            given) or the provider is neither registered nor supported.
            Errors raised by the model constructor (e.g. a missing API key)
            propagate unchanged.

    Example:
        # Load model with provider prefix:
        >>> from langchain_dev_utils.embeddings import load_embeddings
        >>> embeddings = load_embeddings("vllm:qwen3-embedding-4b")
        >>> embeddings.embed_query("hello world")

        # Load model with separate provider parameter:
        >>> embeddings = load_embeddings("qwen3-embedding-4b", provider="vllm")
        >>> embeddings.embed_query("hello world")
    """
    if provider is None:
        provider, model = _parse_model_string(model)

    registered = _EMBEDDINGS_PROVIDERS_DICT.get(provider)
    if registered is None:
        if provider not in _SUPPORTED_PROVIDERS:
            raise ValueError(f"Provider {provider} not registered")
        return init_embeddings(model, provider=provider, **kwargs)

    embeddings_cls = registered["embeddings_model"]
    if base_url := registered.get("base_url"):
        url_key = _get_base_url_field_name(embeddings_cls)
        if url_key is not None:
            # setdefault: the registered URL is only a default and must not
            # silently override a value the caller passed explicitly.
            kwargs.setdefault(url_key, base_url)
    return embeddings_cls(model=model, **kwargs)
@@ -0,0 +1,15 @@
1
+ from .content import (
2
+ aconvert_reasoning_content_for_chunk_iterator,
3
+ convert_reasoning_content_for_ai_message,
4
+ convert_reasoning_content_for_chunk_iterator,
5
+ merge_ai_message_chunk,
6
+ )
7
+ from .format import format_sequence
8
+
9
+ __all__ = [
10
+ "convert_reasoning_content_for_ai_message",
11
+ "convert_reasoning_content_for_chunk_iterator",
12
+ "aconvert_reasoning_content_for_chunk_iterator",
13
+ "merge_ai_message_chunk",
14
+ "format_sequence",
15
+ ]
@@ -0,0 +1,201 @@
1
+ from functools import reduce
2
+ from typing import AsyncIterator, Iterator, Sequence, Tuple, cast
3
+
4
+ from langchain_core.messages import AIMessage, AIMessageChunk
5
+
6
+
7
def _get_reasoning_content(model_response: AIMessage | AIMessageChunk) -> str | None:
    """Return the reasoning text carried by ``model_response``.

    The first content block whose ``type`` is ``"reasoning"`` is consulted
    first. If there is none, or its ``"reasoning"`` value is falsy, the value
    of ``additional_kwargs["reasoning_content"]`` is returned instead. That
    fallback is ``None`` when the key is absent.
    """
    reasoning_blocks = [
        blk for blk in model_response.content_blocks if blk["type"] == "reasoning"
    ]
    from_blocks = reasoning_blocks[0].get("reasoning") if reasoning_blocks else None
    return from_blocks or model_response.additional_kwargs.get("reasoning_content")
20
+
21
+
22
def convert_reasoning_content_for_ai_message(
    model_response: AIMessage,
    think_tag: Tuple[str, str] = ("<think>", "</think>"),
) -> AIMessage:
    """Convert reasoning content in AI message to visible content.

    Reasoning is taken from the message's ``"reasoning"`` content block, or
    failing that from ``additional_kwargs["reasoning_content"]``. It is
    wrapped in the given tags and placed in front of the visible content.

    When the content is a string, the new content is
    ``opening_tag + reasoning + closing_tag + content``. When the content is a
    list of content blocks, a ``{"type": "text"}`` block holding the tagged
    reasoning is prepended to the list. This avoids rendering the list's
    Python repr into the text. A copy is returned; the input is not modified.
    If no reasoning is found, the original message is returned as-is.

    Args:
        model_response: AI message response from model
        think_tag: Tuple of (opening_tag, closing_tag) to wrap reasoning content

    Returns:
        AIMessage: Modified AI message with reasoning content in visible content

    Example:
        # Basic usage with default tags:
        >>> from langchain_dev_utils.message_convert import convert_reasoning_content_for_ai_message
        >>> response = model.invoke("Explain quantum computing")
        >>> response = convert_reasoning_content_for_ai_message(response)
        >>> response.content

        # Custom tags for reasoning content:
        >>> response = convert_reasoning_content_for_ai_message(
        ...     response, think_tag=('<reasoning>', '</reasoning>')
        ... )
        >>> response.content
    """
    reasoning_content = _get_reasoning_content(model_response)
    if not reasoning_content:
        return model_response

    tagged_reasoning = f"{think_tag[0]}{reasoning_content}{think_tag[1]}"
    content = model_response.content
    if isinstance(content, list):
        # Interpolating a list into an f-string would embed its repr in the
        # message text; keep the block structure instead.
        new_content = [{"type": "text", "text": tagged_reasoning}, *content]
    else:
        new_content = f"{tagged_reasoning}{content}"
    return model_response.model_copy(update={"content": new_content})
61
+
62
+
63
def convert_reasoning_content_for_chunk_iterator(
    model_response: Iterator[AIMessageChunk | AIMessage],
    think_tag: Tuple[str, str] = ("<think>", "</think>"),
) -> Iterator[AIMessageChunk | AIMessage]:
    """Fold streamed reasoning into the visible ``content`` of each chunk.

    For each ``AIMessageChunk`` carrying reasoning, the reasoning text replaces
    the chunk's content. The first such chunk is prefixed with the opening
    tag. Once reasoning has started, the first later chunk with non-empty
    content and no reasoning gets the closing tag prepended. If the stream
    ends before such a chunk arrives, no closing tag is emitted. Items that
    are not ``AIMessageChunk`` instances are yielded unchanged. Chunks are
    copied, never mutated.

    Args:
        model_response: Iterator of message chunks from streaming response
        think_tag: Tuple of (opening_tag, closing_tag) to wrap reasoning content

    Yields:
        BaseMessageChunk: Modified message chunks with reasoning content

    Example:
        # Process streaming response:
        >>> from langchain_dev_utils.message_convert import convert_reasoning_content_for_chunk_iterator
        >>> for chunk in convert_reasoning_content_for_chunk_iterator(
        ...     model.stream("What is the capital of France?")
        ... ):
        ...     print(chunk.content, end="", flush=True)

        # Custom tags for streaming:
        >>> for chunk in convert_reasoning_content_for_chunk_iterator(
        ...     model.stream("Explain quantum computing"),
        ...     think_tag=('<reasoning>', '</reasoning>')
        ... ):
        ...     print(chunk.content, end="", flush=True)
    """
    seen_reasoning = False  # opening tag already emitted
    tag_closed = False  # closing tag already emitted

    for item in model_response:
        if isinstance(item, AIMessageChunk):
            reasoning = _get_reasoning_content(item)
            if reasoning:
                if seen_reasoning:
                    new_content = reasoning
                else:
                    new_content = f"{think_tag[0]}{reasoning}"
                    seen_reasoning = True
                item = item.model_copy(update={"content": new_content})
            elif item.content and seen_reasoning and not tag_closed:
                item = item.model_copy(
                    update={"content": f"{think_tag[1]}{item.content}"}
                )
                tag_closed = True
        yield item
116
+
117
+
118
async def aconvert_reasoning_content_for_chunk_iterator(
    model_response: AsyncIterator[AIMessageChunk | AIMessage],
    think_tag: Tuple[str, str] = ("<think>", "</think>"),
) -> AsyncIterator[AIMessageChunk | AIMessage]:
    """Async counterpart of ``convert_reasoning_content_for_chunk_iterator``.

    For each ``AIMessageChunk`` carrying reasoning, the reasoning text replaces
    the chunk's content. The first such chunk is prefixed with the opening
    tag. Once reasoning has started, the first later chunk with non-empty
    content and no reasoning gets the closing tag prepended. If the stream
    ends before such a chunk arrives, no closing tag is emitted. Items that
    are not ``AIMessageChunk`` instances are yielded unchanged.

    Args:
        model_response: Async iterator of message chunks from streaming response
        think_tag: Tuple of (opening_tag, closing_tag) to wrap reasoning content

    Yields:
        BaseMessageChunk: Modified message chunks with reasoning content

    Example:
        # Process async streaming response:
        >>> from langchain_dev_utils.message_convert import aconvert_reasoning_content_for_chunk_iterator
        >>> async for chunk in aconvert_reasoning_content_for_chunk_iterator(
        ...     model.astream("What is the capital of France?")
        ... ):
        ...     print(chunk.content, end="", flush=True)

        # Custom tags for async streaming:
        >>> async for chunk in aconvert_reasoning_content_for_chunk_iterator(
        ...     model.astream("Explain quantum computing"),
        ...     think_tag=('<reasoning>', '</reasoning>')
        ... ):
        ...     print(chunk.content, end="", flush=True)
    """
    seen_reasoning = False  # opening tag already emitted
    tag_closed = False  # closing tag already emitted

    async for item in model_response:
        if isinstance(item, AIMessageChunk):
            reasoning = _get_reasoning_content(item)
            if reasoning:
                if seen_reasoning:
                    new_content = reasoning
                else:
                    new_content = f"{think_tag[0]}{reasoning}"
                    seen_reasoning = True
                item = item.model_copy(update={"content": new_content})
            elif item.content and seen_reasoning and not tag_closed:
                item = item.model_copy(
                    update={"content": f"{think_tag[1]}{item.content}"}
                )
                tag_closed = True
        yield item
170
+
171
+
172
def merge_ai_message_chunk(chunks: Sequence[AIMessageChunk]) -> AIMessage:
    """Merge a sequence of AIMessageChunk into a single AIMessage.

    The chunks are combined left to right with ``+``. The raw ``tool_calls``
    entry of ``additional_kwargs`` is dropped from the result. Parsed
    ``tool_calls``, ``invalid_tool_calls`` and ``usage_metadata`` are carried
    over when present. The input chunks are not modified.

    Args:
        chunks: Non-empty sequence of AIMessageChunk to merge

    Returns:
        AIMessage: Merged AIMessage

    Raises:
        TypeError: If ``chunks`` is empty (raised by ``functools.reduce``).

    Example:
        # Merge streaming chunks:
        >>> from langchain_dev_utils.message_convert import merge_ai_message_chunk
        >>> merged_message = merge_ai_message_chunk(list(model.stream("What is the capital of France?")))
        >>> merged_message.content
    """
    merged = cast(AIMessageChunk, reduce(lambda x, y: x + y, chunks))

    # Work on a copy: with a single chunk, reduce() returns that very chunk,
    # and popping in place would mutate the caller's message.
    additional_kwargs = dict(merged.additional_kwargs)
    additional_kwargs.pop("tool_calls", None)

    data = {
        "id": merged.id,
        "content": merged.content,
        "response_metadata": merged.response_metadata,
        "additional_kwargs": additional_kwargs,
    }
    if tool_calls := getattr(merged, "tool_calls", None):
        data["tool_calls"] = tool_calls
    if invalid_tool_calls := getattr(merged, "invalid_tool_calls", None):
        data["invalid_tool_calls"] = invalid_tool_calls
    # Token usage would otherwise be lost in the conversion to AIMessage.
    if usage_metadata := getattr(merged, "usage_metadata", None):
        data["usage_metadata"] = usage_metadata
    return AIMessage.model_validate(data)
@@ -0,0 +1,69 @@
1
+ from typing import Sequence, Union
2
+
3
+ from langchain_core.documents import Document
4
+ from langchain_core.messages import (
5
+ AIMessage,
6
+ BaseMessage,
7
+ HumanMessage,
8
+ SystemMessage,
9
+ ToolMessage,
10
+ )
11
+
12
+
13
def format_sequence(
    inputs: Union[Sequence[Document], Sequence[BaseMessage], Sequence[str]],
    separator: str = "-",
    with_num: bool = False,
) -> str:
    """Convert a list of messages, documents, or strings into a formatted string.

    Text is extracted from each supported item: a message's text content, a
    document's ``page_content``, or a plain string. Every extracted item is
    prefixed with ``separator`` and items are joined with newlines, e.g.
    ``"-first\\n-second"``. Items of unsupported types are skipped.

    Args:
        inputs: A list of inputs. Supported types:
            - langchain_core.messages.BaseMessage and its subclasses
              (HumanMessage, AIMessage, SystemMessage, ToolMessage, ...);
              list-of-blocks content contributes its string parts and
              ``{"type": "text"}`` blocks
            - langchain_core.documents.Document
            - str
        separator: The prefix placed before every item. Defaults to "-".
        with_num: If True, prefixes each item with a serial number (e.g., "1. Hello").
            Defaults to False.

    Returns:
        The formatted string, or ``""`` when ``inputs`` is empty or contains
        no supported items.

    Example:
        # Format messages with default separator:
        >>> from langchain_dev_utils.message_convert import format_sequence
        >>> from langchain_core.messages import HumanMessage, AIMessage
        >>> messages = [
        ...     HumanMessage(content="Hello, how are you?"),
        ...     AIMessage(content="I'm doing well, thank you!")
        ... ]
        >>> formatted = format_sequence(messages)
        >>> formatted

        # Format with custom separator and numbering:
        >>> formatted = format_sequence(messages, separator="---", with_num=True)
        >>> formatted
    """

    def _content_to_text(content) -> str:
        # Message content is either a string or a list of str / dict blocks;
        # str.join() below needs a string, so flatten list content.
        if isinstance(content, str):
            return content
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                parts.append(str(part.get("text", "")))
        return "".join(parts)

    if not inputs:
        return ""

    outputs = []
    for input_item in inputs:
        if isinstance(input_item, BaseMessage):
            outputs.append(_content_to_text(input_item.content))
        elif isinstance(input_item, Document):
            outputs.append(input_item.page_content)
        elif isinstance(input_item, str):
            outputs.append(input_item)

    # Every item was of an unsupported type: return "" instead of a lone
    # separator.
    if not outputs:
        return ""

    if with_num:
        outputs = [f"{i}. {output}" for i, output in enumerate(outputs, start=1)]

    return separator + ("\n" + separator).join(outputs)
@@ -0,0 +1,7 @@
1
+ from .parallel import create_parallel_pipeline
2
+ from .sequential import create_sequential_pipeline
3
+
4
+ __all__ = [
5
+ "create_parallel_pipeline",
6
+ "create_sequential_pipeline",
7
+ ]
@@ -0,0 +1,135 @@
1
+ from typing import Awaitable, Callable, Optional, Union
2
+
3
+ from langgraph.cache.base import BaseCache
4
+ from langgraph.graph import StateGraph
5
+ from langgraph.graph.state import CompiledStateGraph
6
+ from langgraph.store.base import BaseStore
7
+ from langgraph.types import Checkpointer, Send
8
+ from langgraph.typing import ContextT, InputT, OutputT, StateT
9
+
10
+ from .types import SubGraph
11
+
12
+
13
def create_parallel_pipeline(
    sub_graphs: list[SubGraph],
    state_schema: type[StateT],
    graph_name: Optional[str] = None,
    branches_fn: Optional[
        Union[
            Callable[..., list[Send]],
            Callable[..., Awaitable[list[Send]]],
        ]
    ] = None,
    context_schema: type[ContextT] | None = None,
    input_schema: type[InputT] | None = None,
    output_schema: type[OutputT] | None = None,
    checkpointer: Checkpointer | None = None,
    store: BaseStore | None = None,
    cache: BaseCache | None = None,
) -> CompiledStateGraph[StateT, ContextT, InputT, OutputT]:
    """
    Create a parallel pipeline from a list of subgraphs.

    This function allows you to compose multiple StateGraphs in a parallel fashion,
    where subgraphs can execute concurrently. Each subgraph becomes a node named
    after the subgraph. Without ``branches_fn``, every node gets an edge from
    the graph start. With ``branches_fn``, it is attached as a conditional edge
    from the start, and its ``Send`` objects choose which nodes run and with
    what input. Uncompiled ``StateGraph`` inputs are compiled first.

    Args:
        sub_graphs: List of sub-graphs to execute in parallel
        state_schema: state schema of the final constructed graph
        graph_name: Name of the final constructed graph (defaults to
            "parallel graph")
        branches_fn: Optional function to determine which sub-graphs to execute
            in parallel
        context_schema: context schema of the final constructed graph
        input_schema: input schema of the final constructed graph
        output_schema: output schema of the final constructed graph
        checkpointer: Optional LangGraph checkpointer for the final constructed
            graph
        store: Optional LangGraph store for the final constructed graph
        cache: Optional LangGraph cache for the final constructed graph

    Returns:
        CompiledStateGraph[StateT, ContextT, InputT, OutputT]: Compiled state
            graph of the pipeline.

    Raises:
        ValueError: If a subgraph has no name (or the default "LangGraph"
            name), or if two subgraphs share a name.

    Example:
        # Basic parallel pipeline: multiple specialized agents run concurrently
        >>> from langchain_dev_utils.pipeline import create_parallel_pipeline
        >>>
        >>> graph = create_parallel_pipeline(
        ...     sub_graphs=[
        ...         time_agent, weather_agent, user_agent
        ...     ],
        ...     state_schema=AgentState,
        ...     graph_name="parallel_agents_pipeline",
        ... )
        >>>
        >>> response = graph.invoke({"messages": [HumanMessage("Hello")]})

        # Dynamic parallel pipeline: decide which agents to run based on conditional branches
        >>> graph = create_parallel_pipeline(
        ...     sub_graphs=[
        ...         time_agent, weather_agent, user_agent
        ...     ],
        ...     state_schema=AgentState,
        ...     branches_fn=lambda state: [
        ...         Send("weather_agent", arg={"messages": [HumanMessage("Get current weather in New York")]}),
        ...         Send("time_agent", arg={"messages": [HumanMessage("Get current time")]}),
        ...     ],
        ...     graph_name="dynamic_parallel_pipeline",
        ... )
        >>>
        >>> response = graph.invoke({"messages": [HumanMessage("Hello")]})
    """
    graph = StateGraph(
        state_schema=state_schema,
        context_schema=context_schema,
        input_schema=input_schema,
        output_schema=output_schema,
    )

    compiled_subgraphs: list[CompiledStateGraph] = []
    seen_names: set[str] = set()
    for subgraph in sub_graphs:
        if isinstance(subgraph, StateGraph):
            subgraph = subgraph.compile()

        compiled_subgraphs.append(subgraph)
        # "LangGraph" is treated here as an unnamed graph.
        if subgraph.name is None or subgraph.name == "LangGraph":
            raise ValueError(
                "Please specify a name when you create your agent, either via `create_react_agent(..., name=agent_name)` "
                "or via `graph.compile(name=name)`."
            )

        if subgraph.name in seen_names:
            raise ValueError(
                f"Subgraph with name '{subgraph.name}' already exists. Subgraph names must be unique."
            )

        seen_names.add(subgraph.name)

    node_names = [compiled.name for compiled in compiled_subgraphs]
    for compiled in compiled_subgraphs:
        graph.add_node(compiled.name, compiled)

    if branches_fn:
        # The node-name list is the set of destinations branches_fn may target.
        graph.add_conditional_edges("__start__", branches_fn, node_names)
    else:
        # Static fan-out: every subgraph runs from the start.
        for name in node_names:
            graph.add_edge("__start__", name)

    return graph.compile(
        name=graph_name or "parallel graph",
        checkpointer=checkpointer,
        store=store,
        cache=cache,
    )