langchain 1.0.0a9__py3-none-any.whl → 1.0.0a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain might be problematic. Click here for more details.

Files changed (36)
  1. langchain/__init__.py +1 -24
  2. langchain/_internal/_documents.py +1 -1
  3. langchain/_internal/_prompts.py +2 -2
  4. langchain/_internal/_typing.py +1 -1
  5. langchain/agents/__init__.py +2 -3
  6. langchain/agents/factory.py +1126 -0
  7. langchain/agents/middleware/__init__.py +38 -1
  8. langchain/agents/middleware/context_editing.py +245 -0
  9. langchain/agents/middleware/human_in_the_loop.py +67 -20
  10. langchain/agents/middleware/model_call_limit.py +177 -0
  11. langchain/agents/middleware/model_fallback.py +94 -0
  12. langchain/agents/middleware/pii.py +753 -0
  13. langchain/agents/middleware/planning.py +201 -0
  14. langchain/agents/middleware/prompt_caching.py +7 -4
  15. langchain/agents/middleware/summarization.py +2 -1
  16. langchain/agents/middleware/tool_call_limit.py +260 -0
  17. langchain/agents/middleware/tool_selection.py +306 -0
  18. langchain/agents/middleware/types.py +708 -127
  19. langchain/agents/structured_output.py +15 -1
  20. langchain/chat_models/base.py +22 -25
  21. langchain/embeddings/base.py +3 -4
  22. langchain/embeddings/cache.py +0 -1
  23. langchain/messages/__init__.py +29 -0
  24. langchain/rate_limiters/__init__.py +13 -0
  25. langchain/tools/__init__.py +9 -0
  26. langchain/{agents → tools}/tool_node.py +8 -10
  27. {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/METADATA +29 -35
  28. langchain-1.0.0a11.dist-info/RECORD +43 -0
  29. {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/WHEEL +1 -1
  30. langchain/agents/middleware_agent.py +0 -617
  31. langchain/agents/react_agent.py +0 -1228
  32. langchain/globals.py +0 -18
  33. langchain/text_splitter.py +0 -50
  34. langchain-1.0.0a9.dist-info/RECORD +0 -38
  35. langchain-1.0.0a9.dist-info/entry_points.txt +0 -4
  36. {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/licenses/LICENSE +0 -0
@@ -1,617 +0,0 @@
1
- """Middleware agent implementation."""
2
-
3
- import itertools
4
- from collections.abc import Callable, Sequence
5
- from inspect import signature
6
- from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
7
-
8
- from langchain_core.language_models.chat_models import BaseChatModel
9
- from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
10
- from langchain_core.runnables import Runnable
11
- from langchain_core.tools import BaseTool
12
- from langgraph.constants import END, START
13
- from langgraph.graph.state import StateGraph
14
- from langgraph.runtime import Runtime
15
- from langgraph.types import Send
16
- from langgraph.typing import ContextT
17
- from typing_extensions import NotRequired, Required, TypedDict, TypeVar
18
-
19
- from langchain.agents.middleware.types import (
20
- AgentMiddleware,
21
- AgentState,
22
- JumpTo,
23
- ModelRequest,
24
- OmitFromSchema,
25
- PublicAgentState,
26
- )
27
- from langchain.agents.structured_output import (
28
- MultipleStructuredOutputsError,
29
- OutputToolBinding,
30
- ProviderStrategy,
31
- ProviderStrategyBinding,
32
- ResponseFormat,
33
- StructuredOutputValidationError,
34
- ToolStrategy,
35
- )
36
- from langchain.agents.tool_node import ToolNode
37
- from langchain.chat_models import init_chat_model
38
-
39
# Default retry message sent back to the model when structured output fails;
# ``{error}`` is filled with ``str(exception)``.
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n" " Please fix your mistakes."
40
-
41
-
42
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
    """Merge the type hints of several schemas into one generated ``TypedDict``.

    Args:
        schemas: Set of schema types whose hints (``include_extras=True``) are
            merged. On a field-name clash the schema visited last wins; because
            this is a set, the visiting order is not guaranteed.
        schema_name: Name given to the generated ``TypedDict``.
        omit_flag: Optional ``OmitFromSchema`` attribute name (``"input"`` or
            ``"output"``). When given, a field is dropped if its annotation
            metadata contains an ``OmitFromSchema`` whose attribute of that
            name is ``True``.

    Returns:
        A new ``TypedDict`` type named ``schema_name``.
    """

    def _is_omitted(annotation: Any) -> bool:
        # Without a flag no field is ever omitted.
        if not omit_flag:
            return False
        return any(
            isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True
            for meta in _extract_metadata(annotation)
        )

    merged: dict[str, Any] = {}
    for schema in schemas:
        for name, annotation in get_type_hints(schema, include_extras=True).items():
            if not _is_omitted(annotation):
                merged[name] = annotation
    return TypedDict(schema_name, merged)  # type: ignore[operator]
70
-
71
-
72
- def _extract_metadata(type_: type) -> list:
73
- """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
74
- # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
75
- if get_origin(type_) in (Required, NotRequired):
76
- inner_type = get_args(type_)[0]
77
- if get_origin(inner_type) is Annotated:
78
- return list(get_args(inner_type)[1:])
79
-
80
- # Handle direct Annotated[...]
81
- elif get_origin(type_) is Annotated:
82
- return list(get_args(type_)[1:])
83
-
84
- return []
85
-
86
-
87
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
    """Heuristically decide whether ``model`` supports provider-native structured output.

    Only the model name is inspected. A string is used as-is; a ``BaseChatModel``
    contributes its ``model_name`` attribute if it has one. A name qualifies when
    it contains ``"grok"`` (case-insensitive) or one of a fixed set of OpenAI
    model-family substrings (case-sensitive).
    """
    if isinstance(model, str):
        name = model
    elif isinstance(model, BaseChatModel):
        name = getattr(model, "model_name", None)
    else:
        name = None

    if not name:
        return False
    if "grok" in name.lower():
        return True
    openai_families = ("gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini")
    return any(family in name for family in openai_families)
101
-
102
-
103
def _handle_structured_output_error(
    exception: Exception,
    response_format: ResponseFormat,
) -> tuple[bool, str]:
    """Decide whether a structured-output failure should be retried.

    Only ``ToolStrategy`` response formats can retry. Their ``handle_errors``
    setting selects the outcome:

    * ``True``: retry with ``STRUCTURED_OUTPUT_ERROR_TEMPLATE`` filled with the error.
    * ``False``: do not retry.
    * a ``str``: retry with that string as the message.
    * an ``Exception`` subclass, or a tuple of types: retry with the template
      only when ``exception`` is an instance of it.
    * any other callable: retry with ``handle_errors(exception)`` as the message.

    Returns:
        A ``(should_retry, retry_tool_message)`` pair. The message is ``""``
        when not retrying.
    """
    if not isinstance(response_format, ToolStrategy):
        return False, ""

    handler = response_format.handle_errors

    if handler is False:
        return False, ""
    if handler is True:
        return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
    if isinstance(handler, str):
        return True, handler

    if isinstance(handler, type) and issubclass(handler, Exception):
        matches = isinstance(exception, handler)
    elif isinstance(handler, tuple):
        matches = any(isinstance(exception, exc_type) for exc_type in handler)
    elif callable(handler):
        # type narrowing not working appropriately w/ callable check, can fix later
        return True, handler(exception)  # type: ignore[return-value,call-arg]
    else:
        return False, ""

    if matches:
        return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
    return False, ""
131
-
132
-
133
# Type of the structured response produced when a ``response_format`` is given.
ResponseT = TypeVar("ResponseT")
134
-
135
-
136
def create_agent(  # noqa: PLR0915
    *,
    model: str | BaseChatModel,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
    system_prompt: str | None = None,
    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    context_schema: type[ContextT] | None = None,
) -> StateGraph[
    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
    """Create a middleware agent graph.

    Args:
        model: Chat model instance, or a model identifier string passed to
            ``init_chat_model``.
        tools: Tools available to the model. This is either a sequence of tools
            (``BaseTool`` instances, callables, or provider built-in tool dicts)
            or a pre-built ``ToolNode``. ``None`` means no user tools.
        system_prompt: Default system prompt. Middleware may change it through
            ``modify_model_request``. When non-empty, it is sent as a leading
            ``SystemMessage``.
        middleware: Middleware instances, at most one per class name.
        response_format: Optional structured-output specification. A bare
            schema is wrapped in ``ProviderStrategy`` when the model name
            suggests native support, and in ``ToolStrategy`` otherwise.
        context_schema: Forwarded to ``StateGraph`` as ``context_schema``.

    Returns:
        An uncompiled ``StateGraph`` wiring the middleware hooks, the
        ``"model_request"`` node and, when there are tools, a ``"tools"`` node.

    Raises:
        AssertionError: If two middleware instances share the same class name.
    """
    # init chat model
    if isinstance(model, str):
        model = init_chat_model(model)

    # Setup structured output
    structured_output_tools: dict[str, OutputToolBinding] = {}
    native_output_binding: ProviderStrategyBinding | None = None

    if response_format is not None:
        if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
            # Auto-detect strategy based on model capabilities
            if _supports_native_structured_output(model):
                response_format = ProviderStrategy(schema=response_format)
            else:
                response_format = ToolStrategy(schema=response_format)

        if isinstance(response_format, ToolStrategy):
            # Setup tools strategy for structured output
            for response_schema in response_format.schema_specs:
                structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
                structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
        elif isinstance(response_format, ProviderStrategy):
            # Setup native strategy
            native_output_binding = ProviderStrategyBinding.from_schema_spec(
                response_format.schema_spec
            )

    middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
    structured_tools = [info.tool for info in structured_output_tools.values()]

    # Setup tools
    tool_node: ToolNode | None
    if isinstance(tools, ToolNode):
        user_tools = list(tools.tools_by_name.values())
        # Rebuild only when middleware contributes tools, so a caller-configured
        # ToolNode is otherwise used as-is. Structured-output tool calls are never
        # routed to the tool node (see the model->tools edge), so those tools are
        # only bound to the model.
        tool_node = ToolNode(user_tools + middleware_tools) if middleware_tools else tools
        default_tools = user_tools + middleware_tools + structured_tools
    else:
        # Accept any sequence of tools (list, tuple, ...). ``None`` means no tools.
        tool_list = list(tools) if tools is not None else []
        # Builtin provider tools (dict format) are bound to the model but are not
        # added to the local ToolNode.
        builtin_tools = [t for t in tool_list if isinstance(t, dict)]
        regular_tools = [t for t in tool_list if not isinstance(t, dict)]
        all_tools = middleware_tools + regular_tools + structured_tools
        # Only create ToolNode if we have tools
        tool_node = ToolNode(tools=all_tools) if all_tools else None
        default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools

    # Validate middleware. Node names are derived from class names, so each class
    # may appear only once. The check raises explicitly (not via ``assert``) so it
    # still runs under ``python -O``.
    if len({m.__class__.__name__ for m in middleware}) != len(middleware):
        msg = "Please remove duplicate middleware instances."
        raise AssertionError(msg)

    middleware_w_before = [
        m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
    ]
    middleware_w_modify_model_request = [
        m
        for m in middleware
        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
    ]
    middleware_w_after = [
        m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
    ]

    state_schemas = {m.state_schema for m in middleware}
    state_schemas.add(AgentState)

    state_schema = _resolve_schema(state_schemas, "StateSchema", None)
    input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
    output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")

    # create graph, add nodes
    graph: StateGraph[
        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
    ] = StateGraph(
        state_schema=state_schema,
        input_schema=input_schema,
        output_schema=output_schema,
        context_schema=context_schema,
    )

    def _handle_model_output(output: AIMessage) -> dict[str, Any]:
        """Turn a model response into a state update, parsing structured output if any."""
        # Handle structured output with native strategy
        if isinstance(response_format, ProviderStrategy):
            if not output.tool_calls and native_output_binding:
                structured_response = native_output_binding.parse(output)
                return {"messages": [output], "response": structured_response}
            return {"messages": [output]}

        # Everything below only applies to the tools strategy with tool calls present.
        if not (
            isinstance(response_format, ToolStrategy)
            and isinstance(output, AIMessage)
            and output.tool_calls
        ):
            return {"messages": [output]}

        structured_tool_calls = [
            tc for tc in output.tool_calls if tc["name"] in structured_output_tools
        ]
        if not structured_tool_calls:
            return {"messages": [output]}

        if len(structured_tool_calls) > 1:
            # Handle multiple structured outputs error
            tool_names = [tc["name"] for tc in structured_tool_calls]
            multi_error = MultipleStructuredOutputsError(tool_names)
            should_retry, error_message = _handle_structured_output_error(
                multi_error, response_format
            )
            if not should_retry:
                raise multi_error

            # Add one error message per structured call so the model can retry.
            tool_messages = [
                ToolMessage(
                    content=error_message,
                    tool_call_id=tc["id"],
                    name=tc["name"],
                )
                for tc in structured_tool_calls
            ]
            return {"messages": [output, *tool_messages]}

        # Handle single structured output. Only the parse step is guarded, so
        # errors elsewhere are not misreported as validation failures.
        tool_call = structured_tool_calls[0]
        try:
            structured_tool_binding = structured_output_tools[tool_call["name"]]
            structured_response = structured_tool_binding.parse(tool_call["args"])
        except Exception as exc:  # noqa: BLE001
            validation_error = StructuredOutputValidationError(tool_call["name"], exc)
            should_retry, error_message = _handle_structured_output_error(
                validation_error, response_format
            )
            if not should_retry:
                raise validation_error from exc

            return {
                "messages": [
                    output,
                    ToolMessage(
                        content=error_message,
                        tool_call_id=tool_call["id"],
                        name=tool_call["name"],
                    ),
                ],
            }

        tool_message_content = (
            response_format.tool_message_content
            if response_format.tool_message_content
            else f"Returning structured response: {structured_response}"
        )
        return {
            "messages": [
                output,
                ToolMessage(
                    content=tool_message_content,
                    tool_call_id=tool_call["id"],
                    name=tool_call["name"],
                ),
            ],
            "response": structured_response,
        }

    def _get_bound_model(request: ModelRequest) -> Runnable:
        """Get the model with appropriate tool bindings."""
        if isinstance(response_format, ProviderStrategy):
            # Use native structured output
            kwargs = response_format.to_model_kwargs()
            return request.model.bind_tools(
                request.tools, strict=True, **kwargs, **request.model_settings
            )
        if isinstance(response_format, ToolStrategy):
            tool_choice = "any" if structured_output_tools else request.tool_choice
            return request.model.bind_tools(
                request.tools, tool_choice=tool_choice, **request.model_settings
            )
        # Standard model binding
        if request.tools:
            return request.model.bind_tools(
                request.tools, tool_choice=request.tool_choice, **request.model_settings
            )
        return request.model.bind(**request.model_settings)

    # Pre-compute which modify_model_request hooks accept a ``runtime`` argument.
    model_request_signatures: list[
        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
    ] = [
        ("runtime" in signature(m.modify_model_request).parameters, m)
        for m in middleware_w_modify_model_request
    ]

    def _build_model_input(
        state: AgentState, runtime: Runtime[ContextT]
    ) -> tuple[Runnable, list[AnyMessage]]:
        """Apply modify_model_request middleware and return the bound model and its input."""
        request = ModelRequest(
            model=model,
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=response_format,
            messages=state["messages"],
            tool_choice=None,
        )

        # Apply modify_model_request middleware in sequence
        for use_runtime, m in model_request_signatures:
            if use_runtime:
                m.modify_model_request(request, state, runtime)
            else:
                m.modify_model_request(request, state)  # type: ignore[call-arg]

        bound_model = _get_bound_model(request)
        messages = request.messages
        if request.system_prompt:
            messages = [SystemMessage(request.system_prompt), *messages]
        return bound_model, messages

    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Sync model request handler with sequential middleware processing."""
        model_, messages = _build_model_input(state, runtime)
        output = model_.invoke(messages)
        return _handle_model_output(output)

    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
        model_, messages = _build_model_input(state, runtime)
        output = await model_.ainvoke(messages)
        return _handle_model_output(output)

    # Use sync or async based on model capabilities
    from langgraph._internal._runnable import RunnableCallable

    graph.add_node("model_request", RunnableCallable(model_request, amodel_request))

    # Only add tools node if we have tools
    if tool_node is not None:
        graph.add_node("tools", tool_node)

    # Add middleware nodes
    for m in middleware:
        if m.__class__.before_model is not AgentMiddleware.before_model:
            graph.add_node(
                f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
            )

        if m.__class__.after_model is not AgentMiddleware.after_model:
            graph.add_node(
                f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
            )

    # The model loop is entered at the first before_model hook (or the model itself).
    # It is left at the first middleware's after_model hook, because after_model
    # hooks are chained in reverse order below.
    first_node = (
        f"{middleware_w_before[0].__class__.__name__}.before_model"
        if middleware_w_before
        else "model_request"
    )
    last_node = (
        f"{middleware_w_after[0].__class__.__name__}.after_model"
        if middleware_w_after
        else "model_request"
    )
    graph.add_edge(START, first_node)

    # add conditional edges only if tools exist
    if tool_node is not None:
        graph.add_conditional_edges(
            "tools",
            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
            [first_node, END],
        )
        graph.add_conditional_edges(
            last_node,
            _make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
            [first_node, "tools", END],
        )
    elif last_node == "model_request":
        # If no tools, just go to END from model
        graph.add_edge(last_node, END)
    else:
        # If after_model, then need to check for jump_to
        _add_middleware_edge(
            graph,
            last_node,
            END,
            first_node,
            jump_to=middleware_w_after[0].after_model_jump_to,
        )

    # Chain before_model hooks in declaration order, ending at the model.
    if middleware_w_before:
        for m1, m2 in itertools.pairwise(middleware_w_before):
            _add_middleware_edge(
                graph,
                f"{m1.__class__.__name__}.before_model",
                f"{m2.__class__.__name__}.before_model",
                first_node,
                jump_to=m1.before_model_jump_to,
            )
        # Go directly to model_request after the last before_model
        _add_middleware_edge(
            graph,
            f"{middleware_w_before[-1].__class__.__name__}.before_model",
            "model_request",
            first_node,
            jump_to=middleware_w_before[-1].before_model_jump_to,
        )

    # Chain after_model hooks in reverse declaration order, starting at the model.
    if middleware_w_after:
        graph.add_edge("model_request", f"{middleware_w_after[-1].__class__.__name__}.after_model")
        for m1, m2 in itertools.pairwise(reversed(middleware_w_after)):
            _add_middleware_edge(
                graph,
                f"{m1.__class__.__name__}.after_model",
                f"{m2.__class__.__name__}.after_model",
                first_node,
                jump_to=m1.after_model_jump_to,
            )

    return graph
503
-
504
-
505
- def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
506
- if jump_to == "model":
507
- return first_node
508
- if jump_to == "end":
509
- return "__end__"
510
- if jump_to == "tools":
511
- return "tools"
512
- return None
513
-
514
-
515
def _fetch_last_ai_and_tool_messages(
    messages: list[AnyMessage],
) -> tuple[AIMessage, list[ToolMessage]]:
    """Return the most recent ``AIMessage`` and the ``ToolMessage``s that follow it.

    Args:
        messages: Conversation history, oldest first.

    Returns:
        ``(last_ai_message, tool_messages)``. ``tool_messages`` holds every
        ``ToolMessage`` located after ``last_ai_message``, in order.

    Raises:
        ValueError: If ``messages`` contains no ``AIMessage``. Previously this
            case surfaced as an opaque ``UnboundLocalError``.
    """
    for index in range(len(messages) - 1, -1, -1):
        candidate = messages[index]
        if isinstance(candidate, AIMessage):
            tool_messages = [m for m in messages[index + 1 :] if isinstance(m, ToolMessage)]
            return candidate, tool_messages
    msg = "Expected at least one AIMessage in the message history."
    raise ValueError(msg)
529
-
530
-
531
def _make_model_to_tools_edge(
    first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
    """Build the conditional edge that leaves the model side of the loop.

    The returned router works in three steps:

    1. A truthy ``jump_to`` in state is resolved with ``_resolve_jump``.
    2. Otherwise, each tool call on the latest AI message that has no matching
       ``ToolMessage`` yet and is not a structured-output call is dispatched
       as its own ``Send("tools", [call])``.
    3. Otherwise, the run ends.
    """

    def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
        jump_to = state.get("jump_to")
        if jump_to:
            return _resolve_jump(jump_to, first_node)

        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
        answered_ids = {tm.tool_call_id for tm in tool_messages}

        sends: list[Send] = []
        for call in last_ai_message.tool_calls:
            if call["id"] in answered_ids or call["name"] in structured_output_tools:
                continue
            # State is injected here rather than by the tool node itself because
            # calls are dispatched individually via Send, which gives more
            # intuitive interrupt behavior. Largely internal; can be revisited.
            injected_call = tool_node.inject_tool_args(call, state, None)
            sends.append(Send("tools", [injected_call]))

        return sends or END

    return model_to_tools
560
-
561
-
562
def _make_tools_to_model_edge(
    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
) -> Callable[[dict[str, Any]], str | None]:
    """Build the conditional edge that runs after the ``"tools"`` node.

    The returned router ends the run when either:

    * the latest AI message has at least one tool call that ``tool_node`` knows
      about, and every such call's tool has ``return_direct=True``; or
    * one of the ``ToolMessage``s after that AI message is named after a
      structured-output tool.

    Otherwise it routes back to ``next_node``.
    """

    def tools_to_model(state: dict[str, Any]) -> str | None:
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

        # Only calls the tool node can execute carry a ``return_direct`` flag.
        # Without the emptiness guard, ``all()`` over zero calls is True and would
        # end the run, e.g. when every call names a tool the node does not have.
        executable_calls = [
            c for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name
        ]
        if executable_calls and all(
            tool_node.tools_by_name[c["name"]].return_direct for c in executable_calls
        ):
            return END

        if any(t.name in structured_output_tools for t in tool_messages):
            return END

        return next_node

    return tools_to_model
581
-
582
-
583
def _add_middleware_edge(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    name: str,
    default_destination: str,
    model_destination: str,
    jump_to: list[JumpTo] | None,
) -> None:
    """Add the outgoing edge(s) of a middleware node to the graph.

    Without ``jump_to``, a plain edge to ``default_destination`` is added. With
    ``jump_to``, a conditional edge is added instead. It follows the ``jump_to``
    value found in state (resolved via ``_resolve_jump``) and otherwise falls
    back to ``default_destination``.

    Args:
        graph: The graph to add the edge to.
        name: The name of the middleware node the edge starts from.
        default_destination: Destination used when no jump is requested.
        model_destination: Node that a ``"model"`` jump resolves to.
        jump_to: Jump targets the middleware may request. Each declared target
            is registered as a possible destination of the conditional edge,
            except a ``"model"`` target that would point back at ``name``.
    """
    if not jump_to:
        graph.add_edge(name, default_destination)
        return

    def jump_edge(state: dict[str, Any]) -> str:
        return _resolve_jump(state.get("jump_to"), model_destination) or default_destination

    destinations = [default_destination]
    if "end" in jump_to:
        destinations.append(END)
    if "tools" in jump_to:
        destinations.append("tools")
    # Skip a "model" destination that would make the node route to itself.
    if "model" in jump_to and name != model_destination:
        destinations.append(model_destination)

    graph.add_conditional_edges(name, jump_edge, destinations)