fastmcp 2.14.0__py3-none-any.whl → 2.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastmcp/client/client.py +79 -12
- fastmcp/client/sampling/__init__.py +69 -0
- fastmcp/client/sampling/handlers/__init__.py +0 -0
- fastmcp/client/sampling/handlers/anthropic.py +387 -0
- fastmcp/client/sampling/handlers/openai.py +399 -0
- fastmcp/client/tasks.py +0 -63
- fastmcp/client/transports.py +35 -16
- fastmcp/experimental/sampling/handlers/__init__.py +5 -0
- fastmcp/experimental/sampling/handlers/openai.py +4 -169
- fastmcp/prompts/prompt.py +5 -5
- fastmcp/prompts/prompt_manager.py +3 -4
- fastmcp/resources/resource.py +4 -4
- fastmcp/resources/resource_manager.py +9 -14
- fastmcp/resources/template.py +5 -5
- fastmcp/server/auth/auth.py +20 -5
- fastmcp/server/auth/oauth_proxy.py +73 -15
- fastmcp/server/auth/providers/supabase.py +11 -6
- fastmcp/server/context.py +448 -113
- fastmcp/server/dependencies.py +5 -0
- fastmcp/server/elicitation.py +7 -3
- fastmcp/server/middleware/error_handling.py +1 -1
- fastmcp/server/openapi/components.py +2 -4
- fastmcp/server/proxy.py +3 -3
- fastmcp/server/sampling/__init__.py +10 -0
- fastmcp/server/sampling/run.py +301 -0
- fastmcp/server/sampling/sampling_tool.py +108 -0
- fastmcp/server/server.py +84 -78
- fastmcp/server/tasks/converters.py +2 -1
- fastmcp/tools/tool.py +8 -6
- fastmcp/tools/tool_manager.py +5 -7
- fastmcp/utilities/cli.py +23 -43
- fastmcp/utilities/json_schema.py +40 -0
- fastmcp/utilities/openapi/schemas.py +4 -4
- {fastmcp-2.14.0.dist-info → fastmcp-2.14.2.dist-info}/METADATA +8 -3
- {fastmcp-2.14.0.dist-info → fastmcp-2.14.2.dist-info}/RECORD +38 -34
- fastmcp/client/sampling.py +0 -56
- fastmcp/experimental/sampling/handlers/base.py +0 -21
- fastmcp/server/sampling/handler.py +0 -19
- {fastmcp-2.14.0.dist-info → fastmcp-2.14.2.dist-info}/WHEEL +0 -0
- {fastmcp-2.14.0.dist-info → fastmcp-2.14.2.dist-info}/entry_points.txt +0 -0
- {fastmcp-2.14.0.dist-info → fastmcp-2.14.2.dist-info}/licenses/LICENSE +0 -0
fastmcp/server/context.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import copy
|
|
4
|
-
import
|
|
4
|
+
import json
|
|
5
5
|
import logging
|
|
6
6
|
import weakref
|
|
7
|
-
from collections.abc import Generator, Mapping, Sequence
|
|
7
|
+
from collections.abc import Callable, Generator, Mapping, Sequence
|
|
8
8
|
from contextlib import contextmanager
|
|
9
9
|
from contextvars import ContextVar, Token
|
|
10
10
|
from dataclasses import dataclass
|
|
11
11
|
from logging import Logger
|
|
12
|
-
from typing import Any, overload
|
|
12
|
+
from typing import Any, Literal, cast, overload
|
|
13
13
|
|
|
14
14
|
import anyio
|
|
15
15
|
from mcp import LoggingLevel, ServerSession
|
|
@@ -17,25 +17,27 @@ from mcp.server.lowlevel.helper_types import ReadResourceContents
|
|
|
17
17
|
from mcp.server.lowlevel.server import request_ctx
|
|
18
18
|
from mcp.shared.context import RequestContext
|
|
19
19
|
from mcp.types import (
|
|
20
|
-
ClientCapabilities,
|
|
21
20
|
CreateMessageResult,
|
|
21
|
+
CreateMessageResultWithTools,
|
|
22
22
|
GetPromptResult,
|
|
23
|
-
IncludeContext,
|
|
24
|
-
ModelHint,
|
|
25
23
|
ModelPreferences,
|
|
26
24
|
Root,
|
|
27
|
-
SamplingCapability,
|
|
28
25
|
SamplingMessage,
|
|
29
26
|
SamplingMessageContentBlock,
|
|
30
27
|
TextContent,
|
|
28
|
+
ToolChoice,
|
|
29
|
+
ToolResultContent,
|
|
30
|
+
ToolUseContent,
|
|
31
31
|
)
|
|
32
|
-
from mcp.types import
|
|
33
|
-
from mcp.types import
|
|
34
|
-
from mcp.types import
|
|
32
|
+
from mcp.types import Prompt as SDKPrompt
|
|
33
|
+
from mcp.types import Resource as SDKResource
|
|
34
|
+
from mcp.types import Tool as SDKTool
|
|
35
|
+
from pydantic import ValidationError
|
|
35
36
|
from pydantic.networks import AnyUrl
|
|
36
37
|
from starlette.requests import Request
|
|
37
38
|
from typing_extensions import TypeVar
|
|
38
39
|
|
|
40
|
+
from fastmcp import settings
|
|
39
41
|
from fastmcp.server.elicitation import (
|
|
40
42
|
AcceptedElicitation,
|
|
41
43
|
CancelledElicitation,
|
|
@@ -43,8 +45,19 @@ from fastmcp.server.elicitation import (
|
|
|
43
45
|
handle_elicit_accept,
|
|
44
46
|
parse_elicit_response_type,
|
|
45
47
|
)
|
|
48
|
+
from fastmcp.server.sampling import SampleStep, SamplingResult, SamplingTool
|
|
49
|
+
from fastmcp.server.sampling.run import (
|
|
50
|
+
_parse_model_preferences,
|
|
51
|
+
call_sampling_handler,
|
|
52
|
+
determine_handler_mode,
|
|
53
|
+
)
|
|
54
|
+
from fastmcp.server.sampling.run import (
|
|
55
|
+
execute_tools as run_sampling_tools,
|
|
56
|
+
)
|
|
46
57
|
from fastmcp.server.server import FastMCP
|
|
58
|
+
from fastmcp.utilities.json_schema import compress_schema
|
|
47
59
|
from fastmcp.utilities.logging import _clamp_logger, get_logger
|
|
60
|
+
from fastmcp.utilities.types import get_cached_typeadapter
|
|
48
61
|
|
|
49
62
|
logger: Logger = get_logger(name=__name__)
|
|
50
63
|
to_client_logger: Logger = logger.getChild(suffix="to_client")
|
|
@@ -56,8 +69,14 @@ _clamp_logger(logger=to_client_logger, max_level="DEBUG")
|
|
|
56
69
|
|
|
57
70
|
|
|
58
71
|
T = TypeVar("T", default=Any)
|
|
72
|
+
ResultT = TypeVar("ResultT", default=str)
|
|
73
|
+
|
|
74
|
+
# Simplified tool choice type - just the mode string instead of the full MCP object
|
|
75
|
+
ToolChoiceOption = Literal["auto", "required", "none"]
|
|
59
76
|
|
|
60
77
|
_current_context: ContextVar[Context | None] = ContextVar("context", default=None) # type: ignore[assignment]
|
|
78
|
+
|
|
79
|
+
|
|
61
80
|
_flush_lock = anyio.Lock()
|
|
62
81
|
|
|
63
82
|
|
|
@@ -245,7 +264,7 @@ class Context:
|
|
|
245
264
|
related_request_id=self.request_id,
|
|
246
265
|
)
|
|
247
266
|
|
|
248
|
-
async def list_resources(self) -> list[
|
|
267
|
+
async def list_resources(self) -> list[SDKResource]:
|
|
249
268
|
"""List all available resources from the server.
|
|
250
269
|
|
|
251
270
|
Returns:
|
|
@@ -253,7 +272,7 @@ class Context:
|
|
|
253
272
|
"""
|
|
254
273
|
return await self.fastmcp._list_resources_mcp()
|
|
255
274
|
|
|
256
|
-
async def list_prompts(self) -> list[
|
|
275
|
+
async def list_prompts(self) -> list[SDKPrompt]:
|
|
257
276
|
"""List all available prompts from the server.
|
|
258
277
|
|
|
259
278
|
Returns:
|
|
@@ -523,88 +542,343 @@ class Context:
|
|
|
523
542
|
return
|
|
524
543
|
await self.request_context.close_sse_stream()
|
|
525
544
|
|
|
526
|
-
async def
|
|
545
|
+
async def sample_step(
|
|
527
546
|
self,
|
|
528
547
|
messages: str | Sequence[str | SamplingMessage],
|
|
548
|
+
*,
|
|
529
549
|
system_prompt: str | None = None,
|
|
530
|
-
include_context: IncludeContext | None = None,
|
|
531
550
|
temperature: float | None = None,
|
|
532
551
|
max_tokens: int | None = None,
|
|
533
552
|
model_preferences: ModelPreferences | str | list[str] | None = None,
|
|
534
|
-
|
|
553
|
+
tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
|
|
554
|
+
tool_choice: ToolChoiceOption | str | None = None,
|
|
555
|
+
execute_tools: bool = True,
|
|
556
|
+
mask_error_details: bool | None = None,
|
|
557
|
+
) -> SampleStep:
|
|
535
558
|
"""
|
|
536
|
-
|
|
559
|
+
Make a single LLM sampling call.
|
|
537
560
|
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
561
|
+
This is a stateless function that makes exactly one LLM call and optionally
|
|
562
|
+
executes any requested tools. Use this for fine-grained control over the
|
|
563
|
+
sampling loop.
|
|
564
|
+
|
|
565
|
+
Args:
|
|
566
|
+
messages: The message(s) to send. Can be a string, list of strings,
|
|
567
|
+
or list of SamplingMessage objects.
|
|
568
|
+
system_prompt: Optional system prompt for the LLM.
|
|
569
|
+
temperature: Optional sampling temperature.
|
|
570
|
+
max_tokens: Maximum tokens to generate. Defaults to 512.
|
|
571
|
+
model_preferences: Optional model preferences.
|
|
572
|
+
tools: Optional list of tools the LLM can use.
|
|
573
|
+
tool_choice: Tool choice mode ("auto", "required", or "none").
|
|
574
|
+
execute_tools: If True (default), execute tool calls and append results
|
|
575
|
+
to history. If False, return immediately with tool_calls available
|
|
576
|
+
in the step for manual execution.
|
|
577
|
+
mask_error_details: If True, mask detailed error messages from tool
|
|
578
|
+
execution. When None (default), uses the global settings value.
|
|
579
|
+
Tools can raise ToolError to bypass masking.
|
|
580
|
+
|
|
581
|
+
Returns:
|
|
582
|
+
SampleStep containing:
|
|
583
|
+
- .response: The raw LLM response
|
|
584
|
+
- .history: Messages including input, assistant response, and tool results
|
|
585
|
+
- .is_tool_use: True if the LLM requested tool execution
|
|
586
|
+
- .tool_calls: List of tool calls (if any)
|
|
587
|
+
- .text: The text content (if any)
|
|
588
|
+
|
|
589
|
+
Example:
|
|
590
|
+
messages = "Research X"
|
|
591
|
+
|
|
592
|
+
while True:
|
|
593
|
+
step = await ctx.sample_step(messages, tools=[search])
|
|
594
|
+
|
|
595
|
+
if not step.is_tool_use:
|
|
596
|
+
print(step.text)
|
|
597
|
+
break
|
|
598
|
+
|
|
599
|
+
# Continue with tool results
|
|
600
|
+
messages = step.history
|
|
541
601
|
"""
|
|
602
|
+
# Convert messages to SamplingMessage objects
|
|
603
|
+
current_messages = _prepare_messages(messages)
|
|
604
|
+
|
|
605
|
+
# Convert tools to SamplingTools
|
|
606
|
+
sampling_tools = _prepare_tools(tools)
|
|
607
|
+
sdk_tools: list[SDKTool] | None = (
|
|
608
|
+
[t._to_sdk_tool() for t in sampling_tools] if sampling_tools else None
|
|
609
|
+
)
|
|
610
|
+
tool_map: dict[str, SamplingTool] = (
|
|
611
|
+
{t.name: t for t in sampling_tools} if sampling_tools else {}
|
|
612
|
+
)
|
|
542
613
|
|
|
543
|
-
|
|
544
|
-
|
|
614
|
+
# Determine whether to use fallback handler or client
|
|
615
|
+
use_fallback = determine_handler_mode(self, bool(sampling_tools))
|
|
545
616
|
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
617
|
+
# Build tool choice
|
|
618
|
+
effective_tool_choice: ToolChoice | None = None
|
|
619
|
+
if tool_choice is not None:
|
|
620
|
+
if tool_choice not in ("auto", "required", "none"):
|
|
621
|
+
raise ValueError(
|
|
622
|
+
f"Invalid tool_choice: {tool_choice!r}. "
|
|
623
|
+
"Must be 'auto', 'required', or 'none'."
|
|
550
624
|
)
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
625
|
+
effective_tool_choice = ToolChoice(
|
|
626
|
+
mode=cast(Literal["auto", "required", "none"], tool_choice)
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
# Effective max_tokens
|
|
630
|
+
effective_max_tokens = max_tokens if max_tokens is not None else 512
|
|
631
|
+
|
|
632
|
+
# Make the LLM call
|
|
633
|
+
if use_fallback:
|
|
634
|
+
response = await call_sampling_handler(
|
|
635
|
+
self,
|
|
636
|
+
current_messages,
|
|
637
|
+
system_prompt=system_prompt,
|
|
638
|
+
temperature=temperature,
|
|
639
|
+
max_tokens=effective_max_tokens,
|
|
640
|
+
model_preferences=model_preferences,
|
|
641
|
+
sdk_tools=sdk_tools,
|
|
642
|
+
tool_choice=effective_tool_choice,
|
|
643
|
+
)
|
|
644
|
+
else:
|
|
645
|
+
response = await self.session.create_message(
|
|
646
|
+
messages=current_messages,
|
|
647
|
+
system_prompt=system_prompt,
|
|
648
|
+
temperature=temperature,
|
|
649
|
+
max_tokens=effective_max_tokens,
|
|
650
|
+
model_preferences=_parse_model_preferences(model_preferences),
|
|
651
|
+
tools=sdk_tools,
|
|
652
|
+
tool_choice=effective_tool_choice,
|
|
653
|
+
related_request_id=self.request_id,
|
|
564
654
|
)
|
|
655
|
+
|
|
656
|
+
# Check if this is a tool use response
|
|
657
|
+
is_tool_use_response = (
|
|
658
|
+
isinstance(response, CreateMessageResultWithTools)
|
|
659
|
+
and response.stopReason == "toolUse"
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
# Always include the assistant response in history
|
|
663
|
+
current_messages.append(
|
|
664
|
+
SamplingMessage(role="assistant", content=response.content)
|
|
565
665
|
)
|
|
566
666
|
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
667
|
+
# If not a tool use, return immediately
|
|
668
|
+
if not is_tool_use_response:
|
|
669
|
+
return SampleStep(response=response, history=current_messages)
|
|
670
|
+
|
|
671
|
+
# If not executing tools, return with assistant message but no tool results
|
|
672
|
+
if not execute_tools:
|
|
673
|
+
return SampleStep(response=response, history=current_messages)
|
|
674
|
+
|
|
675
|
+
# Execute tools and add results to history
|
|
676
|
+
step_tool_calls = _extract_tool_calls(response)
|
|
677
|
+
if step_tool_calls:
|
|
678
|
+
effective_mask = (
|
|
679
|
+
mask_error_details
|
|
680
|
+
if mask_error_details is not None
|
|
681
|
+
else settings.mask_error_details
|
|
682
|
+
)
|
|
683
|
+
tool_results = await run_sampling_tools(
|
|
684
|
+
step_tool_calls, tool_map, mask_error_details=effective_mask
|
|
581
685
|
)
|
|
582
686
|
|
|
583
|
-
if
|
|
584
|
-
|
|
687
|
+
if tool_results:
|
|
688
|
+
current_messages.append(
|
|
689
|
+
SamplingMessage(
|
|
690
|
+
role="user",
|
|
691
|
+
content=tool_results, # type: ignore[arg-type]
|
|
692
|
+
)
|
|
693
|
+
)
|
|
585
694
|
|
|
586
|
-
|
|
587
|
-
return TextContent(text=create_message_result, type="text")
|
|
695
|
+
return SampleStep(response=response, history=current_messages)
|
|
588
696
|
|
|
589
|
-
|
|
590
|
-
|
|
697
|
+
@overload
|
|
698
|
+
async def sample(
|
|
699
|
+
self,
|
|
700
|
+
messages: str | Sequence[str | SamplingMessage],
|
|
701
|
+
*,
|
|
702
|
+
system_prompt: str | None = None,
|
|
703
|
+
temperature: float | None = None,
|
|
704
|
+
max_tokens: int | None = None,
|
|
705
|
+
model_preferences: ModelPreferences | str | list[str] | None = None,
|
|
706
|
+
tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
|
|
707
|
+
result_type: type[ResultT],
|
|
708
|
+
mask_error_details: bool | None = None,
|
|
709
|
+
) -> SamplingResult[ResultT]:
|
|
710
|
+
"""Overload: With result_type, returns SamplingResult[ResultT]."""
|
|
591
711
|
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
712
|
+
@overload
|
|
713
|
+
async def sample(
|
|
714
|
+
self,
|
|
715
|
+
messages: str | Sequence[str | SamplingMessage],
|
|
716
|
+
*,
|
|
717
|
+
system_prompt: str | None = None,
|
|
718
|
+
temperature: float | None = None,
|
|
719
|
+
max_tokens: int | None = None,
|
|
720
|
+
model_preferences: ModelPreferences | str | list[str] | None = None,
|
|
721
|
+
tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
|
|
722
|
+
result_type: None = None,
|
|
723
|
+
mask_error_details: bool | None = None,
|
|
724
|
+
) -> SamplingResult[str]:
|
|
725
|
+
"""Overload: Without result_type, returns SamplingResult[str]."""
|
|
726
|
+
|
|
727
|
+
async def sample(
|
|
728
|
+
self,
|
|
729
|
+
messages: str | Sequence[str | SamplingMessage],
|
|
730
|
+
*,
|
|
731
|
+
system_prompt: str | None = None,
|
|
732
|
+
temperature: float | None = None,
|
|
733
|
+
max_tokens: int | None = None,
|
|
734
|
+
model_preferences: ModelPreferences | str | list[str] | None = None,
|
|
735
|
+
tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
|
|
736
|
+
result_type: type[ResultT] | None = None,
|
|
737
|
+
mask_error_details: bool | None = None,
|
|
738
|
+
) -> SamplingResult[ResultT] | SamplingResult[str]:
|
|
739
|
+
"""
|
|
740
|
+
Send a sampling request to the client and await the response.
|
|
741
|
+
|
|
742
|
+
This method runs to completion automatically. When tools are provided,
|
|
743
|
+
it executes a tool loop: if the LLM returns a tool use request, the tools
|
|
744
|
+
are executed and the results are sent back to the LLM. This continues
|
|
745
|
+
until the LLM provides a final text response.
|
|
746
|
+
|
|
747
|
+
When result_type is specified, a synthetic `final_response` tool is
|
|
748
|
+
created. The LLM calls this tool to provide the structured response,
|
|
749
|
+
which is validated against the result_type and returned as `.result`.
|
|
750
|
+
|
|
751
|
+
For fine-grained control over the sampling loop, use sample_step() instead.
|
|
752
|
+
|
|
753
|
+
Args:
|
|
754
|
+
messages: The message(s) to send. Can be a string, list of strings,
|
|
755
|
+
or list of SamplingMessage objects.
|
|
756
|
+
system_prompt: Optional system prompt for the LLM.
|
|
757
|
+
temperature: Optional sampling temperature.
|
|
758
|
+
max_tokens: Maximum tokens to generate. Defaults to 512.
|
|
759
|
+
model_preferences: Optional model preferences.
|
|
760
|
+
tools: Optional list of tools the LLM can use. Accepts plain
|
|
761
|
+
functions or SamplingTools.
|
|
762
|
+
result_type: Optional type for structured output. When specified,
|
|
763
|
+
a synthetic `final_response` tool is created and the LLM's
|
|
764
|
+
response is validated against this type.
|
|
765
|
+
mask_error_details: If True, mask detailed error messages from tool
|
|
766
|
+
execution. When None (default), uses the global settings value.
|
|
767
|
+
Tools can raise ToolError to bypass masking.
|
|
768
|
+
|
|
769
|
+
Returns:
|
|
770
|
+
SamplingResult[T] containing:
|
|
771
|
+
- .text: The text representation (raw text or JSON for structured)
|
|
772
|
+
- .result: The typed result (str for text, parsed object for structured)
|
|
773
|
+
- .history: All messages exchanged during sampling
|
|
774
|
+
"""
|
|
775
|
+
# Safety limit to prevent infinite loops
|
|
776
|
+
max_iterations = 100
|
|
777
|
+
|
|
778
|
+
# Convert tools to SamplingTools
|
|
779
|
+
sampling_tools = _prepare_tools(tools)
|
|
780
|
+
|
|
781
|
+
# Handle structured output with result_type
|
|
782
|
+
tool_choice: str | None = None
|
|
783
|
+
if result_type is not None and result_type is not str:
|
|
784
|
+
final_response_tool = _create_final_response_tool(result_type)
|
|
785
|
+
sampling_tools = list(sampling_tools) if sampling_tools else []
|
|
786
|
+
sampling_tools.append(final_response_tool)
|
|
787
|
+
|
|
788
|
+
# Always require tool calls when result_type is set - the LLM must
|
|
789
|
+
# eventually call final_response (text responses are not accepted)
|
|
790
|
+
tool_choice = "required"
|
|
791
|
+
|
|
792
|
+
# Convert messages for the loop
|
|
793
|
+
current_messages: str | Sequence[str | SamplingMessage] = messages
|
|
794
|
+
|
|
795
|
+
for _iteration in range(max_iterations):
|
|
796
|
+
step = await self.sample_step(
|
|
797
|
+
messages=current_messages,
|
|
798
|
+
system_prompt=system_prompt,
|
|
799
|
+
temperature=temperature,
|
|
800
|
+
max_tokens=max_tokens,
|
|
801
|
+
model_preferences=model_preferences,
|
|
802
|
+
tools=sampling_tools,
|
|
803
|
+
tool_choice=tool_choice,
|
|
804
|
+
mask_error_details=mask_error_details,
|
|
805
|
+
)
|
|
806
|
+
|
|
807
|
+
# Check for final_response tool call for structured output
|
|
808
|
+
if result_type is not None and result_type is not str and step.is_tool_use:
|
|
809
|
+
for tool_call in step.tool_calls:
|
|
810
|
+
if tool_call.name == "final_response":
|
|
811
|
+
# Validate and return the structured result
|
|
812
|
+
type_adapter = get_cached_typeadapter(result_type)
|
|
813
|
+
|
|
814
|
+
# Unwrap if we wrapped primitives (non-object schemas)
|
|
815
|
+
input_data = tool_call.input
|
|
816
|
+
original_schema = compress_schema(
|
|
817
|
+
type_adapter.json_schema(), prune_titles=True
|
|
818
|
+
)
|
|
819
|
+
if (
|
|
820
|
+
original_schema.get("type") != "object"
|
|
821
|
+
and isinstance(input_data, dict)
|
|
822
|
+
and "value" in input_data
|
|
823
|
+
):
|
|
824
|
+
input_data = input_data["value"]
|
|
825
|
+
|
|
826
|
+
try:
|
|
827
|
+
validated_result = type_adapter.validate_python(input_data)
|
|
828
|
+
text = json.dumps(
|
|
829
|
+
type_adapter.dump_python(validated_result, mode="json")
|
|
830
|
+
)
|
|
831
|
+
return SamplingResult(
|
|
832
|
+
text=text,
|
|
833
|
+
result=validated_result,
|
|
834
|
+
history=step.history,
|
|
835
|
+
)
|
|
836
|
+
except ValidationError as e:
|
|
837
|
+
# Validation failed - add error as tool result
|
|
838
|
+
step.history.append(
|
|
839
|
+
SamplingMessage(
|
|
840
|
+
role="user",
|
|
841
|
+
content=[
|
|
842
|
+
ToolResultContent(
|
|
843
|
+
type="tool_result",
|
|
844
|
+
toolUseId=tool_call.id,
|
|
845
|
+
content=[
|
|
846
|
+
TextContent(
|
|
847
|
+
type="text",
|
|
848
|
+
text=(
|
|
849
|
+
f"Validation error: {e}. "
|
|
850
|
+
"Please try again with valid data."
|
|
851
|
+
),
|
|
852
|
+
)
|
|
853
|
+
],
|
|
854
|
+
isError=True,
|
|
855
|
+
)
|
|
856
|
+
], # type: ignore[arg-type]
|
|
857
|
+
)
|
|
858
|
+
)
|
|
859
|
+
|
|
860
|
+
# If not a tool use response, we're done
|
|
861
|
+
if not step.is_tool_use:
|
|
862
|
+
# For structured output, the LLM must use the final_response tool
|
|
863
|
+
if result_type is not None and result_type is not str:
|
|
864
|
+
raise RuntimeError(
|
|
865
|
+
f"Expected structured output of type {result_type.__name__}, "
|
|
866
|
+
"but the LLM returned a text response instead of calling "
|
|
867
|
+
"the final_response tool."
|
|
868
|
+
)
|
|
869
|
+
return SamplingResult(
|
|
870
|
+
text=step.text,
|
|
871
|
+
result=cast(ResultT, step.text if step.text else ""),
|
|
872
|
+
history=step.history,
|
|
595
873
|
)
|
|
596
874
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
max_tokens=max_tokens,
|
|
603
|
-
model_preferences=_parse_model_preferences(model_preferences),
|
|
604
|
-
related_request_id=self.request_id,
|
|
605
|
-
)
|
|
875
|
+
# Continue with the updated history
|
|
876
|
+
current_messages = step.history
|
|
877
|
+
|
|
878
|
+
# After first iteration, reset tool_choice to auto
|
|
879
|
+
tool_choice = None
|
|
606
880
|
|
|
607
|
-
|
|
881
|
+
raise RuntimeError(f"Sampling exceeded maximum iterations ({max_iterations})")
|
|
608
882
|
|
|
609
883
|
@overload
|
|
610
884
|
async def elicit(
|
|
@@ -641,11 +915,12 @@ class Context:
|
|
|
641
915
|
async def elicit(
|
|
642
916
|
self,
|
|
643
917
|
message: str,
|
|
644
|
-
response_type: type[T] | list[str] | None = None,
|
|
918
|
+
response_type: type[T] | list[str] | dict[str, dict[str, str]] | None = None,
|
|
645
919
|
) -> (
|
|
646
920
|
AcceptedElicitation[T]
|
|
647
921
|
| AcceptedElicitation[dict[str, Any]]
|
|
648
922
|
| AcceptedElicitation[str]
|
|
923
|
+
| AcceptedElicitation[list[str]]
|
|
649
924
|
| DeclinedElicitation
|
|
650
925
|
| CancelledElicitation
|
|
651
926
|
):
|
|
@@ -728,47 +1003,6 @@ class Context:
|
|
|
728
1003
|
pass
|
|
729
1004
|
|
|
730
1005
|
|
|
731
|
-
def _parse_model_preferences(
|
|
732
|
-
model_preferences: ModelPreferences | str | list[str] | None,
|
|
733
|
-
) -> ModelPreferences | None:
|
|
734
|
-
"""
|
|
735
|
-
Validates and converts user input for model_preferences into a ModelPreferences object.
|
|
736
|
-
|
|
737
|
-
Args:
|
|
738
|
-
model_preferences (ModelPreferences | str | list[str] | None):
|
|
739
|
-
The model preferences to use. Accepts:
|
|
740
|
-
- ModelPreferences (returns as-is)
|
|
741
|
-
- str (single model hint)
|
|
742
|
-
- list[str] (multiple model hints)
|
|
743
|
-
- None (no preferences)
|
|
744
|
-
|
|
745
|
-
Returns:
|
|
746
|
-
ModelPreferences | None: The parsed ModelPreferences object, or None if not provided.
|
|
747
|
-
|
|
748
|
-
Raises:
|
|
749
|
-
ValueError: If the input is not a supported type or contains invalid values.
|
|
750
|
-
"""
|
|
751
|
-
if model_preferences is None:
|
|
752
|
-
return None
|
|
753
|
-
elif isinstance(model_preferences, ModelPreferences):
|
|
754
|
-
return model_preferences
|
|
755
|
-
elif isinstance(model_preferences, str):
|
|
756
|
-
# Single model hint
|
|
757
|
-
return ModelPreferences(hints=[ModelHint(name=model_preferences)])
|
|
758
|
-
elif isinstance(model_preferences, list):
|
|
759
|
-
# List of model hints (strings)
|
|
760
|
-
if not all(isinstance(h, str) for h in model_preferences):
|
|
761
|
-
raise ValueError(
|
|
762
|
-
"All elements of model_preferences list must be"
|
|
763
|
-
" strings (model name hints)."
|
|
764
|
-
)
|
|
765
|
-
return ModelPreferences(hints=[ModelHint(name=h) for h in model_preferences])
|
|
766
|
-
else:
|
|
767
|
-
raise ValueError(
|
|
768
|
-
"model_preferences must be one of: ModelPreferences, str, list[str], or None."
|
|
769
|
-
)
|
|
770
|
-
|
|
771
|
-
|
|
772
1006
|
async def _log_to_server_and_client(
|
|
773
1007
|
data: LogData,
|
|
774
1008
|
session: ServerSession,
|
|
@@ -795,3 +1029,104 @@ async def _log_to_server_and_client(
|
|
|
795
1029
|
logger=logger_name,
|
|
796
1030
|
related_request_id=related_request_id,
|
|
797
1031
|
)
|
|
1032
|
+
|
|
1033
|
+
|
|
1034
|
+
def _create_final_response_tool(result_type: type) -> SamplingTool:
|
|
1035
|
+
"""Create a synthetic 'final_response' tool for structured output.
|
|
1036
|
+
|
|
1037
|
+
This tool is used to capture structured responses from the LLM.
|
|
1038
|
+
The tool's schema is derived from the result_type.
|
|
1039
|
+
"""
|
|
1040
|
+
type_adapter = get_cached_typeadapter(result_type)
|
|
1041
|
+
schema = type_adapter.json_schema()
|
|
1042
|
+
schema = compress_schema(schema, prune_titles=True)
|
|
1043
|
+
|
|
1044
|
+
# Tool parameters must be object-shaped. Wrap primitives in {"value": <schema>}
|
|
1045
|
+
if schema.get("type") != "object":
|
|
1046
|
+
schema = {
|
|
1047
|
+
"type": "object",
|
|
1048
|
+
"properties": {"value": schema},
|
|
1049
|
+
"required": ["value"],
|
|
1050
|
+
}
|
|
1051
|
+
|
|
1052
|
+
# The fn just returns the input as-is (validation happens in the loop)
|
|
1053
|
+
def final_response(**kwargs: Any) -> dict[str, Any]:
|
|
1054
|
+
return kwargs
|
|
1055
|
+
|
|
1056
|
+
return SamplingTool(
|
|
1057
|
+
name="final_response",
|
|
1058
|
+
description=(
|
|
1059
|
+
"Call this tool to provide your final response. "
|
|
1060
|
+
"Use this when you have completed the task and are ready to return the result."
|
|
1061
|
+
),
|
|
1062
|
+
parameters=schema,
|
|
1063
|
+
fn=final_response,
|
|
1064
|
+
)
|
|
1065
|
+
|
|
1066
|
+
|
|
1067
|
+
def _extract_text_from_content(
|
|
1068
|
+
content: SamplingMessageContentBlock | list[SamplingMessageContentBlock],
|
|
1069
|
+
) -> str | None:
|
|
1070
|
+
"""Extract text from content block(s).
|
|
1071
|
+
|
|
1072
|
+
Returns the text if content is a TextContent or list containing TextContent,
|
|
1073
|
+
otherwise returns None.
|
|
1074
|
+
"""
|
|
1075
|
+
if isinstance(content, list):
|
|
1076
|
+
for block in content:
|
|
1077
|
+
if isinstance(block, TextContent):
|
|
1078
|
+
return block.text
|
|
1079
|
+
return None
|
|
1080
|
+
elif isinstance(content, TextContent):
|
|
1081
|
+
return content.text
|
|
1082
|
+
return None
|
|
1083
|
+
|
|
1084
|
+
|
|
1085
|
+
def _prepare_messages(
|
|
1086
|
+
messages: str | Sequence[str | SamplingMessage],
|
|
1087
|
+
) -> list[SamplingMessage]:
|
|
1088
|
+
"""Convert various message formats to a list of SamplingMessage objects."""
|
|
1089
|
+
if isinstance(messages, str):
|
|
1090
|
+
return [
|
|
1091
|
+
SamplingMessage(
|
|
1092
|
+
content=TextContent(text=messages, type="text"), role="user"
|
|
1093
|
+
)
|
|
1094
|
+
]
|
|
1095
|
+
else:
|
|
1096
|
+
return [
|
|
1097
|
+
SamplingMessage(content=TextContent(text=m, type="text"), role="user")
|
|
1098
|
+
if isinstance(m, str)
|
|
1099
|
+
else m
|
|
1100
|
+
for m in messages
|
|
1101
|
+
]
|
|
1102
|
+
|
|
1103
|
+
|
|
1104
|
+
def _prepare_tools(
|
|
1105
|
+
tools: Sequence[SamplingTool | Callable[..., Any]] | None,
|
|
1106
|
+
) -> list[SamplingTool] | None:
|
|
1107
|
+
"""Convert tools to SamplingTool objects."""
|
|
1108
|
+
if tools is None:
|
|
1109
|
+
return None
|
|
1110
|
+
|
|
1111
|
+
sampling_tools: list[SamplingTool] = []
|
|
1112
|
+
for t in tools:
|
|
1113
|
+
if isinstance(t, SamplingTool):
|
|
1114
|
+
sampling_tools.append(t)
|
|
1115
|
+
elif callable(t):
|
|
1116
|
+
sampling_tools.append(SamplingTool.from_function(t))
|
|
1117
|
+
else:
|
|
1118
|
+
raise TypeError(f"Expected SamplingTool or callable, got {type(t)}")
|
|
1119
|
+
|
|
1120
|
+
return sampling_tools if sampling_tools else None
|
|
1121
|
+
|
|
1122
|
+
|
|
1123
|
+
def _extract_tool_calls(
|
|
1124
|
+
response: CreateMessageResult | CreateMessageResultWithTools,
|
|
1125
|
+
) -> list[ToolUseContent]:
|
|
1126
|
+
"""Extract tool calls from a response."""
|
|
1127
|
+
content = response.content
|
|
1128
|
+
if isinstance(content, list):
|
|
1129
|
+
return [c for c in content if isinstance(c, ToolUseContent)]
|
|
1130
|
+
elif isinstance(content, ToolUseContent):
|
|
1131
|
+
return [content]
|
|
1132
|
+
return []
|