langchain 1.2.3__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +55 -40
- langchain/agents/middleware/__init__.py +15 -18
- langchain/agents/middleware/_execution.py +8 -12
- langchain/agents/middleware/_redaction.py +81 -10
- langchain/agents/middleware/context_editing.py +21 -3
- langchain/agents/middleware/file_search.py +1 -1
- langchain/agents/middleware/human_in_the_loop.py +31 -7
- langchain/agents/middleware/model_call_limit.py +1 -1
- langchain/agents/middleware/model_retry.py +8 -1
- langchain/agents/middleware/pii.py +4 -4
- langchain/agents/middleware/shell_tool.py +26 -6
- langchain/agents/middleware/summarization.py +35 -10
- langchain/agents/middleware/todo.py +30 -16
- langchain/agents/middleware/tool_emulator.py +5 -5
- langchain/agents/middleware/tool_retry.py +15 -8
- langchain/agents/middleware/tool_selection.py +45 -11
- langchain/agents/middleware/types.py +110 -43
- langchain/agents/structured_output.py +43 -30
- langchain/chat_models/base.py +25 -17
- {langchain-1.2.3.dist-info → langchain-1.2.4.dist-info}/METADATA +3 -3
- langchain-1.2.4.dist-info/RECORD +36 -0
- langchain-1.2.3.dist-info/RECORD +0 -36
- {langchain-1.2.3.dist-info → langchain-1.2.4.dist-info}/WHEEL +0 -0
- {langchain-1.2.3.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -78,9 +78,10 @@ class _ModelRequestOverrides(TypedDict, total=False):
|
|
|
78
78
|
system_message: SystemMessage | None
|
|
79
79
|
messages: list[AnyMessage]
|
|
80
80
|
tool_choice: Any | None
|
|
81
|
-
tools: list[BaseTool | dict]
|
|
82
|
-
response_format: ResponseFormat | None
|
|
81
|
+
tools: list[BaseTool | dict[str, Any]]
|
|
82
|
+
response_format: ResponseFormat[Any] | None
|
|
83
83
|
model_settings: dict[str, Any]
|
|
84
|
+
state: AgentState[Any]
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
@dataclass(init=False)
|
|
@@ -91,9 +92,9 @@ class ModelRequest:
|
|
|
91
92
|
messages: list[AnyMessage] # excluding system message
|
|
92
93
|
system_message: SystemMessage | None
|
|
93
94
|
tool_choice: Any | None
|
|
94
|
-
tools: list[BaseTool | dict]
|
|
95
|
-
response_format: ResponseFormat | None
|
|
96
|
-
state: AgentState
|
|
95
|
+
tools: list[BaseTool | dict[str, Any]]
|
|
96
|
+
response_format: ResponseFormat[Any] | None
|
|
97
|
+
state: AgentState[Any]
|
|
97
98
|
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
|
98
99
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
99
100
|
|
|
@@ -105,9 +106,9 @@ class ModelRequest:
|
|
|
105
106
|
system_message: SystemMessage | None = None,
|
|
106
107
|
system_prompt: str | None = None,
|
|
107
108
|
tool_choice: Any | None = None,
|
|
108
|
-
tools: list[BaseTool | dict] | None = None,
|
|
109
|
-
response_format: ResponseFormat | None = None,
|
|
110
|
-
state: AgentState | None = None,
|
|
109
|
+
tools: list[BaseTool | dict[str, Any]] | None = None,
|
|
110
|
+
response_format: ResponseFormat[Any] | None = None,
|
|
111
|
+
state: AgentState[Any] | None = None,
|
|
111
112
|
runtime: Runtime[ContextT] | None = None,
|
|
112
113
|
model_settings: dict[str, Any] | None = None,
|
|
113
114
|
) -> None:
|
|
@@ -124,6 +125,9 @@ class ModelRequest:
|
|
|
124
125
|
model_settings: Additional model settings.
|
|
125
126
|
system_message: System message instance (preferred).
|
|
126
127
|
system_prompt: System prompt string (deprecated, converted to SystemMessage).
|
|
128
|
+
|
|
129
|
+
Raises:
|
|
130
|
+
ValueError: If both `system_prompt` and `system_message` are provided.
|
|
127
131
|
"""
|
|
128
132
|
# Handle system_prompt/system_message conversion and validation
|
|
129
133
|
if system_prompt is not None and system_message is not None:
|
|
@@ -210,6 +214,7 @@ class ModelRequest:
|
|
|
210
214
|
- `tools`: `list` of available tools
|
|
211
215
|
- `response_format`: Response format specification
|
|
212
216
|
- `model_settings`: Additional model settings
|
|
217
|
+
- `state`: Agent state dictionary
|
|
213
218
|
|
|
214
219
|
Returns:
|
|
215
220
|
New `ModelRequest` instance with specified overrides applied.
|
|
@@ -239,6 +244,9 @@ class ModelRequest:
|
|
|
239
244
|
system_message=SystemMessage(content="New instructions"),
|
|
240
245
|
)
|
|
241
246
|
```
|
|
247
|
+
|
|
248
|
+
Raises:
|
|
249
|
+
ValueError: If both `system_prompt` and `system_message` are provided.
|
|
242
250
|
"""
|
|
243
251
|
# Handle system_prompt/system_message conversion
|
|
244
252
|
if "system_prompt" in overrides and "system_message" in overrides:
|
|
@@ -246,7 +254,7 @@ class ModelRequest:
|
|
|
246
254
|
raise ValueError(msg)
|
|
247
255
|
|
|
248
256
|
if "system_prompt" in overrides:
|
|
249
|
-
system_prompt = cast("str", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
|
|
257
|
+
system_prompt = cast("str | None", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
|
|
250
258
|
if system_prompt is None:
|
|
251
259
|
overrides["system_message"] = None
|
|
252
260
|
else:
|
|
@@ -313,7 +321,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
|
|
313
321
|
class _InputAgentState(TypedDict): # noqa: PYI049
|
|
314
322
|
"""Input state schema for the agent."""
|
|
315
323
|
|
|
316
|
-
messages: Required[Annotated[list[AnyMessage | dict], add_messages]]
|
|
324
|
+
messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]]
|
|
317
325
|
|
|
318
326
|
|
|
319
327
|
class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
|
@@ -323,9 +331,13 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
|
|
323
331
|
structured_response: NotRequired[ResponseT]
|
|
324
332
|
|
|
325
333
|
|
|
326
|
-
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
|
327
|
-
StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
|
|
328
|
-
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
|
334
|
+
StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any])
|
|
335
|
+
StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True)
|
|
336
|
+
StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class _DefaultAgentState(AgentState[Any]):
|
|
340
|
+
"""AgentMiddleware default state."""
|
|
329
341
|
|
|
330
342
|
|
|
331
343
|
class AgentMiddleware(Generic[StateT, ContextT]):
|
|
@@ -335,7 +347,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
335
347
|
between steps in the main agent loop.
|
|
336
348
|
"""
|
|
337
349
|
|
|
338
|
-
state_schema: type[StateT] = cast("type[StateT]",
|
|
350
|
+
state_schema: type[StateT] = cast("type[StateT]", _DefaultAgentState)
|
|
339
351
|
"""The schema for state passed to the middleware nodes."""
|
|
340
352
|
|
|
341
353
|
tools: Sequence[BaseTool]
|
|
@@ -352,35 +364,74 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
352
364
|
def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
353
365
|
"""Logic to run before the agent execution starts.
|
|
354
366
|
|
|
355
|
-
|
|
367
|
+
Args:
|
|
368
|
+
state: The current agent state.
|
|
369
|
+
runtime: The runtime context.
|
|
370
|
+
|
|
371
|
+
Returns:
|
|
372
|
+
Agent state updates to apply before agent execution.
|
|
356
373
|
"""
|
|
357
374
|
|
|
358
375
|
async def abefore_agent(
|
|
359
376
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
360
377
|
) -> dict[str, Any] | None:
|
|
361
|
-
"""Async logic to run before the agent execution starts.
|
|
378
|
+
"""Async logic to run before the agent execution starts.
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
state: The current agent state.
|
|
382
|
+
runtime: The runtime context.
|
|
383
|
+
|
|
384
|
+
Returns:
|
|
385
|
+
Agent state updates to apply before agent execution.
|
|
386
|
+
"""
|
|
362
387
|
|
|
363
388
|
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
364
389
|
"""Logic to run before the model is called.
|
|
365
390
|
|
|
366
|
-
|
|
391
|
+
Args:
|
|
392
|
+
state: The current agent state.
|
|
393
|
+
runtime: The runtime context.
|
|
394
|
+
|
|
395
|
+
Returns:
|
|
396
|
+
Agent state updates to apply before model call.
|
|
367
397
|
"""
|
|
368
398
|
|
|
369
399
|
async def abefore_model(
|
|
370
400
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
371
401
|
) -> dict[str, Any] | None:
|
|
372
|
-
"""Async logic to run before the model is called.
|
|
402
|
+
"""Async logic to run before the model is called.
|
|
403
|
+
|
|
404
|
+
Args:
|
|
405
|
+
state: The agent state.
|
|
406
|
+
runtime: The runtime context.
|
|
407
|
+
|
|
408
|
+
Returns:
|
|
409
|
+
Agent state updates to apply before model call.
|
|
410
|
+
"""
|
|
373
411
|
|
|
374
412
|
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
375
413
|
"""Logic to run after the model is called.
|
|
376
414
|
|
|
377
|
-
|
|
415
|
+
Args:
|
|
416
|
+
state: The current agent state.
|
|
417
|
+
runtime: The runtime context.
|
|
418
|
+
|
|
419
|
+
Returns:
|
|
420
|
+
Agent state updates to apply after model call.
|
|
378
421
|
"""
|
|
379
422
|
|
|
380
423
|
async def aafter_model(
|
|
381
424
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
382
425
|
) -> dict[str, Any] | None:
|
|
383
|
-
"""Async logic to run after the model is called.
|
|
426
|
+
"""Async logic to run after the model is called.
|
|
427
|
+
|
|
428
|
+
Args:
|
|
429
|
+
state: The current agent state.
|
|
430
|
+
runtime: The runtime context.
|
|
431
|
+
|
|
432
|
+
Returns:
|
|
433
|
+
Agent state updates to apply after model call.
|
|
434
|
+
"""
|
|
384
435
|
|
|
385
436
|
def wrap_model_call(
|
|
386
437
|
self,
|
|
@@ -408,7 +459,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
408
459
|
Can skip calling it to short-circuit.
|
|
409
460
|
|
|
410
461
|
Returns:
|
|
411
|
-
|
|
462
|
+
The model call result.
|
|
412
463
|
|
|
413
464
|
Examples:
|
|
414
465
|
!!! example "Retry on error"
|
|
@@ -502,7 +553,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
502
553
|
Can skip calling it to short-circuit.
|
|
503
554
|
|
|
504
555
|
Returns:
|
|
505
|
-
|
|
556
|
+
The model call result.
|
|
506
557
|
|
|
507
558
|
Examples:
|
|
508
559
|
!!! example "Retry on error"
|
|
@@ -530,18 +581,34 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
530
581
|
raise NotImplementedError(msg)
|
|
531
582
|
|
|
532
583
|
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
533
|
-
"""Logic to run after the agent execution completes.
|
|
584
|
+
"""Logic to run after the agent execution completes.
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
state: The current agent state.
|
|
588
|
+
runtime: The runtime context.
|
|
589
|
+
|
|
590
|
+
Returns:
|
|
591
|
+
Agent state updates to apply after agent execution.
|
|
592
|
+
"""
|
|
534
593
|
|
|
535
594
|
async def aafter_agent(
|
|
536
595
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
537
596
|
) -> dict[str, Any] | None:
|
|
538
|
-
"""Async logic to run after the agent execution completes.
|
|
597
|
+
"""Async logic to run after the agent execution completes.
|
|
598
|
+
|
|
599
|
+
Args:
|
|
600
|
+
state: The current agent state.
|
|
601
|
+
runtime: The runtime context.
|
|
602
|
+
|
|
603
|
+
Returns:
|
|
604
|
+
Agent state updates to apply after agent execution.
|
|
605
|
+
"""
|
|
539
606
|
|
|
540
607
|
def wrap_tool_call(
|
|
541
608
|
self,
|
|
542
609
|
request: ToolCallRequest,
|
|
543
|
-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
544
|
-
) -> ToolMessage | Command:
|
|
610
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
|
611
|
+
) -> ToolMessage | Command[Any]:
|
|
545
612
|
"""Intercept tool execution for retries, monitoring, or modification.
|
|
546
613
|
|
|
547
614
|
Async version is `awrap_tool_call`
|
|
@@ -622,8 +689,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
622
689
|
async def awrap_tool_call(
|
|
623
690
|
self,
|
|
624
691
|
request: ToolCallRequest,
|
|
625
|
-
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
626
|
-
) -> ToolMessage | Command:
|
|
692
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
|
693
|
+
) -> ToolMessage | Command[Any]:
|
|
627
694
|
"""Intercept and control async tool execution via handler callback.
|
|
628
695
|
|
|
629
696
|
The handler callback executes the tool call and returns a `ToolMessage` or
|
|
@@ -694,7 +761,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
|
|
694
761
|
|
|
695
762
|
def __call__(
|
|
696
763
|
self, state: StateT_contra, runtime: Runtime[ContextT]
|
|
697
|
-
) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
|
|
764
|
+
) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
|
|
698
765
|
"""Perform some logic with the state and runtime."""
|
|
699
766
|
...
|
|
700
767
|
|
|
@@ -735,8 +802,8 @@ class _CallableReturningToolResponse(Protocol):
|
|
|
735
802
|
def __call__(
|
|
736
803
|
self,
|
|
737
804
|
request: ToolCallRequest,
|
|
738
|
-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
739
|
-
) -> ToolMessage | Command:
|
|
805
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
|
806
|
+
) -> ToolMessage | Command[Any]:
|
|
740
807
|
"""Intercept tool execution via handler callback."""
|
|
741
808
|
...
|
|
742
809
|
|
|
@@ -918,7 +985,7 @@ def before_model(
|
|
|
918
985
|
_self: AgentMiddleware[StateT, ContextT],
|
|
919
986
|
state: StateT,
|
|
920
987
|
runtime: Runtime[ContextT],
|
|
921
|
-
) -> dict[str, Any] | Command | None:
|
|
988
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
922
989
|
return await func(state, runtime) # type: ignore[misc]
|
|
923
990
|
|
|
924
991
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -943,7 +1010,7 @@ def before_model(
|
|
|
943
1010
|
_self: AgentMiddleware[StateT, ContextT],
|
|
944
1011
|
state: StateT,
|
|
945
1012
|
runtime: Runtime[ContextT],
|
|
946
|
-
) -> dict[str, Any] | Command | None:
|
|
1013
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
947
1014
|
return func(state, runtime) # type: ignore[return-value]
|
|
948
1015
|
|
|
949
1016
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1078,7 +1145,7 @@ def after_model(
|
|
|
1078
1145
|
_self: AgentMiddleware[StateT, ContextT],
|
|
1079
1146
|
state: StateT,
|
|
1080
1147
|
runtime: Runtime[ContextT],
|
|
1081
|
-
) -> dict[str, Any] | Command | None:
|
|
1148
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1082
1149
|
return await func(state, runtime) # type: ignore[misc]
|
|
1083
1150
|
|
|
1084
1151
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1101,7 +1168,7 @@ def after_model(
|
|
|
1101
1168
|
_self: AgentMiddleware[StateT, ContextT],
|
|
1102
1169
|
state: StateT,
|
|
1103
1170
|
runtime: Runtime[ContextT],
|
|
1104
|
-
) -> dict[str, Any] | Command | None:
|
|
1171
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1105
1172
|
return func(state, runtime) # type: ignore[return-value]
|
|
1106
1173
|
|
|
1107
1174
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1269,7 +1336,7 @@ def before_agent(
|
|
|
1269
1336
|
_self: AgentMiddleware[StateT, ContextT],
|
|
1270
1337
|
state: StateT,
|
|
1271
1338
|
runtime: Runtime[ContextT],
|
|
1272
|
-
) -> dict[str, Any] | Command | None:
|
|
1339
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1273
1340
|
return await func(state, runtime) # type: ignore[misc]
|
|
1274
1341
|
|
|
1275
1342
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1294,7 +1361,7 @@ def before_agent(
|
|
|
1294
1361
|
_self: AgentMiddleware[StateT, ContextT],
|
|
1295
1362
|
state: StateT,
|
|
1296
1363
|
runtime: Runtime[ContextT],
|
|
1297
|
-
) -> dict[str, Any] | Command | None:
|
|
1364
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1298
1365
|
return func(state, runtime) # type: ignore[return-value]
|
|
1299
1366
|
|
|
1300
1367
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1430,7 +1497,7 @@ def after_agent(
|
|
|
1430
1497
|
_self: AgentMiddleware[StateT, ContextT],
|
|
1431
1498
|
state: StateT,
|
|
1432
1499
|
runtime: Runtime[ContextT],
|
|
1433
|
-
) -> dict[str, Any] | Command | None:
|
|
1500
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1434
1501
|
return await func(state, runtime) # type: ignore[misc]
|
|
1435
1502
|
|
|
1436
1503
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1453,7 +1520,7 @@ def after_agent(
|
|
|
1453
1520
|
_self: AgentMiddleware[StateT, ContextT],
|
|
1454
1521
|
state: StateT,
|
|
1455
1522
|
runtime: Runtime[ContextT],
|
|
1456
|
-
) -> dict[str, Any] | Command | None:
|
|
1523
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1457
1524
|
return func(state, runtime) # type: ignore[return-value]
|
|
1458
1525
|
|
|
1459
1526
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1901,8 +1968,8 @@ def wrap_tool_call(
|
|
|
1901
1968
|
async def async_wrapped(
|
|
1902
1969
|
_self: AgentMiddleware,
|
|
1903
1970
|
request: ToolCallRequest,
|
|
1904
|
-
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
1905
|
-
) -> ToolMessage | Command:
|
|
1971
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
|
1972
|
+
) -> ToolMessage | Command[Any]:
|
|
1906
1973
|
return await func(request, handler) # type: ignore[arg-type,misc]
|
|
1907
1974
|
|
|
1908
1975
|
middleware_name = name or cast(
|
|
@@ -1922,8 +1989,8 @@ def wrap_tool_call(
|
|
|
1922
1989
|
def wrapped(
|
|
1923
1990
|
_self: AgentMiddleware,
|
|
1924
1991
|
request: ToolCallRequest,
|
|
1925
|
-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
1926
|
-
) -> ToolMessage | Command:
|
|
1992
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
|
1993
|
+
) -> ToolMessage | Command[Any]:
|
|
1927
1994
|
return func(request, handler)
|
|
1928
1995
|
|
|
1929
1996
|
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|
|
@@ -75,7 +75,7 @@ class StructuredOutputValidationError(StructuredOutputError):
|
|
|
75
75
|
|
|
76
76
|
|
|
77
77
|
def _parse_with_schema(
|
|
78
|
-
schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
|
|
78
|
+
schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any]
|
|
79
79
|
) -> Any:
|
|
80
80
|
"""Parse data using for any supported schema type.
|
|
81
81
|
|
|
@@ -106,21 +106,21 @@ def _parse_with_schema(
|
|
|
106
106
|
class _SchemaSpec(Generic[SchemaT]):
|
|
107
107
|
"""Describes a structured output schema."""
|
|
108
108
|
|
|
109
|
-
schema: type[SchemaT]
|
|
109
|
+
schema: type[SchemaT] | dict[str, Any]
|
|
110
110
|
"""The schema for the response, can be a Pydantic model, `dataclass`, `TypedDict`,
|
|
111
111
|
or JSON schema dict."""
|
|
112
112
|
|
|
113
113
|
name: str
|
|
114
114
|
"""Name of the schema, used for tool calling.
|
|
115
115
|
|
|
116
|
-
If not provided, the name will be the
|
|
117
|
-
JSON
|
|
116
|
+
If not provided, the name will be the class name for models/dataclasses/TypedDicts,
|
|
117
|
+
or the `title` field for JSON schemas. Falls back to a generated name if unavailable.
|
|
118
118
|
"""
|
|
119
119
|
|
|
120
120
|
description: str
|
|
121
121
|
"""Custom description of the schema.
|
|
122
122
|
|
|
123
|
-
If not provided,
|
|
123
|
+
If not provided, will use the model's docstring.
|
|
124
124
|
"""
|
|
125
125
|
|
|
126
126
|
schema_kind: SchemaKind
|
|
@@ -134,13 +134,23 @@ class _SchemaSpec(Generic[SchemaT]):
|
|
|
134
134
|
|
|
135
135
|
def __init__(
|
|
136
136
|
self,
|
|
137
|
-
schema: type[SchemaT],
|
|
137
|
+
schema: type[SchemaT] | dict[str, Any],
|
|
138
138
|
*,
|
|
139
139
|
name: str | None = None,
|
|
140
140
|
description: str | None = None,
|
|
141
141
|
strict: bool | None = None,
|
|
142
142
|
) -> None:
|
|
143
|
-
"""Initialize SchemaSpec with schema and optional parameters.
|
|
143
|
+
"""Initialize `SchemaSpec` with schema and optional parameters.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
schema: Schema to describe.
|
|
147
|
+
name: Optional name for the schema.
|
|
148
|
+
description: Optional description for the schema.
|
|
149
|
+
strict: Whether to enforce strict validation of the schema.
|
|
150
|
+
|
|
151
|
+
Raises:
|
|
152
|
+
ValueError: If the schema type is unsupported.
|
|
153
|
+
"""
|
|
144
154
|
self.schema = schema
|
|
145
155
|
|
|
146
156
|
if name:
|
|
@@ -182,10 +192,10 @@ class _SchemaSpec(Generic[SchemaT]):
|
|
|
182
192
|
class ToolStrategy(Generic[SchemaT]):
|
|
183
193
|
"""Use a tool calling strategy for model responses."""
|
|
184
194
|
|
|
185
|
-
schema: type[SchemaT]
|
|
195
|
+
schema: type[SchemaT] | UnionType | dict[str, Any]
|
|
186
196
|
"""Schema for the tool calls."""
|
|
187
197
|
|
|
188
|
-
schema_specs: list[_SchemaSpec[
|
|
198
|
+
schema_specs: list[_SchemaSpec[Any]]
|
|
189
199
|
"""Schema specs for the tool calls."""
|
|
190
200
|
|
|
191
201
|
tool_message_content: str | None
|
|
@@ -208,7 +218,7 @@ class ToolStrategy(Generic[SchemaT]):
|
|
|
208
218
|
|
|
209
219
|
def __init__(
|
|
210
220
|
self,
|
|
211
|
-
schema: type[SchemaT],
|
|
221
|
+
schema: type[SchemaT] | UnionType | dict[str, Any],
|
|
212
222
|
*,
|
|
213
223
|
tool_message_content: str | None = None,
|
|
214
224
|
handle_errors: bool
|
|
@@ -247,7 +257,7 @@ class ToolStrategy(Generic[SchemaT]):
|
|
|
247
257
|
class ProviderStrategy(Generic[SchemaT]):
|
|
248
258
|
"""Use the model provider's native structured output method."""
|
|
249
259
|
|
|
250
|
-
schema: type[SchemaT]
|
|
260
|
+
schema: type[SchemaT] | dict[str, Any]
|
|
251
261
|
"""Schema for native mode."""
|
|
252
262
|
|
|
253
263
|
schema_spec: _SchemaSpec[SchemaT]
|
|
@@ -255,7 +265,7 @@ class ProviderStrategy(Generic[SchemaT]):
|
|
|
255
265
|
|
|
256
266
|
def __init__(
|
|
257
267
|
self,
|
|
258
|
-
schema: type[SchemaT],
|
|
268
|
+
schema: type[SchemaT] | dict[str, Any],
|
|
259
269
|
*,
|
|
260
270
|
strict: bool | None = None,
|
|
261
271
|
) -> None:
|
|
@@ -269,7 +279,11 @@ class ProviderStrategy(Generic[SchemaT]):
|
|
|
269
279
|
self.schema_spec = _SchemaSpec(schema, strict=strict)
|
|
270
280
|
|
|
271
281
|
def to_model_kwargs(self) -> dict[str, Any]:
|
|
272
|
-
"""Convert to kwargs to bind to a model to force structured output.
|
|
282
|
+
"""Convert to kwargs to bind to a model to force structured output.
|
|
283
|
+
|
|
284
|
+
Returns:
|
|
285
|
+
The kwargs to bind to a model.
|
|
286
|
+
"""
|
|
273
287
|
# OpenAI:
|
|
274
288
|
# - see https://platform.openai.com/docs/guides/structured-outputs
|
|
275
289
|
json_schema: dict[str, Any] = {
|
|
@@ -295,7 +309,7 @@ class OutputToolBinding(Generic[SchemaT]):
|
|
|
295
309
|
and the corresponding tool implementation used by the tools strategy.
|
|
296
310
|
"""
|
|
297
311
|
|
|
298
|
-
schema: type[SchemaT]
|
|
312
|
+
schema: type[SchemaT] | dict[str, Any]
|
|
299
313
|
"""The original schema provided for structured output
|
|
300
314
|
(Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
|
|
301
315
|
|
|
@@ -349,7 +363,7 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
349
363
|
its type classification, and parsing logic for provider-enforced JSON.
|
|
350
364
|
"""
|
|
351
365
|
|
|
352
|
-
schema: type[SchemaT]
|
|
366
|
+
schema: type[SchemaT] | dict[str, Any]
|
|
353
367
|
"""The original schema provided for structured output
|
|
354
368
|
(Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict)."""
|
|
355
369
|
|
|
@@ -399,7 +413,8 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
399
413
|
# Parse according to schema
|
|
400
414
|
return _parse_with_schema(self.schema, self.schema_kind, data)
|
|
401
415
|
|
|
402
|
-
|
|
416
|
+
@staticmethod
|
|
417
|
+
def _extract_text_content_from_message(message: AIMessage) -> str:
|
|
403
418
|
"""Extract text content from an AIMessage.
|
|
404
419
|
|
|
405
420
|
Args:
|
|
@@ -411,29 +426,27 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
411
426
|
content = message.content
|
|
412
427
|
if isinstance(content, str):
|
|
413
428
|
return content
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
if
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
return "".join(parts)
|
|
425
|
-
return str(content)
|
|
429
|
+
parts: list[str] = []
|
|
430
|
+
for c in content:
|
|
431
|
+
if isinstance(c, dict):
|
|
432
|
+
if c.get("type") == "text" and "text" in c:
|
|
433
|
+
parts.append(str(c["text"]))
|
|
434
|
+
elif "content" in c and isinstance(c["content"], str):
|
|
435
|
+
parts.append(c["content"])
|
|
436
|
+
else:
|
|
437
|
+
parts.append(str(c))
|
|
438
|
+
return "".join(parts)
|
|
426
439
|
|
|
427
440
|
|
|
428
441
|
class AutoStrategy(Generic[SchemaT]):
|
|
429
442
|
"""Automatically select the best strategy for structured output."""
|
|
430
443
|
|
|
431
|
-
schema: type[SchemaT]
|
|
444
|
+
schema: type[SchemaT] | dict[str, Any]
|
|
432
445
|
"""Schema for automatic mode."""
|
|
433
446
|
|
|
434
447
|
def __init__(
|
|
435
448
|
self,
|
|
436
|
-
schema: type[SchemaT],
|
|
449
|
+
schema: type[SchemaT] | dict[str, Any],
|
|
437
450
|
) -> None:
|
|
438
451
|
"""Initialize AutoStrategy with schema."""
|
|
439
452
|
self.schema = schema
|