fastmcp-2.2.6-py3-none-any.whl → fastmcp-2.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastmcp/tools/tool.py CHANGED
@@ -6,13 +6,19 @@ from collections.abc import Callable
6
6
  from typing import TYPE_CHECKING, Annotated, Any
7
7
 
8
8
  import pydantic_core
9
- from mcp.types import EmbeddedResource, ImageContent, TextContent
9
+ from mcp.types import EmbeddedResource, ImageContent, TextContent, ToolAnnotations
10
10
  from mcp.types import Tool as MCPTool
11
11
  from pydantic import BaseModel, BeforeValidator, Field
12
12
 
13
13
  from fastmcp.exceptions import ToolError
14
- from fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
15
- from fastmcp.utilities.types import Image, _convert_set_defaults
14
+ from fastmcp.utilities.json_schema import prune_params
15
+ from fastmcp.utilities.logging import get_logger
16
+ from fastmcp.utilities.types import (
17
+ Image,
18
+ _convert_set_defaults,
19
+ find_kwarg_by_type,
20
+ get_cached_typeadapter,
21
+ )
16
22
 
17
23
  if TYPE_CHECKING:
18
24
  from mcp.server.session import ServerSessionT
@@ -20,6 +26,12 @@ if TYPE_CHECKING:
20
26
 
21
27
  from fastmcp.server import Context
22
28
 
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ def default_serializer(data: Any) -> str:
33
+ return pydantic_core.to_json(data, fallback=str, indent=2).decode()
34
+
23
35
 
24
36
  class Tool(BaseModel):
25
37
  """Internal tool registration info."""
@@ -28,17 +40,18 @@ class Tool(BaseModel):
28
40
  name: str = Field(description="Name of the tool")
29
41
  description: str = Field(description="Description of what the tool does")
30
42
  parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
31
- fn_metadata: FuncMetadata = Field(
32
- description="Metadata about the function including a pydantic model for tool"
33
- " arguments"
34
- )
35
- is_async: bool = Field(description="Whether the tool is async")
36
43
  context_kwarg: str | None = Field(
37
44
  None, description="Name of the kwarg that should receive context"
38
45
  )
39
46
  tags: Annotated[set[str], BeforeValidator(_convert_set_defaults)] = Field(
40
47
  default_factory=set, description="Tags for the tool"
41
48
  )
49
+ annotations: ToolAnnotations | None = Field(
50
+ None, description="Additional annotations about the tool"
51
+ )
52
+ serializer: Callable[[Any], str] | None = Field(
53
+ None, description="Optional custom serializer for tool results"
54
+ )
42
55
 
43
56
  @classmethod
44
57
  def from_function(
@@ -48,50 +61,44 @@ class Tool(BaseModel):
48
61
  description: str | None = None,
49
62
  context_kwarg: str | None = None,
50
63
  tags: set[str] | None = None,
64
+ annotations: ToolAnnotations | None = None,
65
+ serializer: Callable[[Any], str] | None = None,
51
66
  ) -> Tool:
52
67
  """Create a Tool from a function."""
53
68
  from fastmcp import Context
54
69
 
70
+ # Reject functions with *args or **kwargs
71
+ sig = inspect.signature(fn)
72
+ for param in sig.parameters.values():
73
+ if param.kind == inspect.Parameter.VAR_POSITIONAL:
74
+ raise ValueError("Functions with *args are not supported as tools")
75
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
76
+ raise ValueError("Functions with **kwargs are not supported as tools")
77
+
55
78
  func_name = name or fn.__name__
56
79
 
57
80
  if func_name == "<lambda>":
58
81
  raise ValueError("You must provide a name for lambda functions")
59
82
 
60
83
  func_doc = description or fn.__doc__ or ""
61
- is_async = inspect.iscoroutinefunction(fn)
84
+
85
+ type_adapter = get_cached_typeadapter(fn)
86
+ schema = type_adapter.json_schema()
62
87
 
63
88
  if context_kwarg is None:
64
- if inspect.ismethod(fn) and hasattr(fn, "__func__"):
65
- sig = inspect.signature(fn.__func__)
66
- else:
67
- sig = inspect.signature(fn)
68
- for param_name, param in sig.parameters.items():
69
- if param.annotation is Context:
70
- context_kwarg = param_name
71
- break
72
-
73
- # Use callable typing to ensure fn is treated as a callable despite being a classmethod
74
- fn_callable: Callable[..., Any] = fn
75
- func_arg_metadata = func_metadata(
76
- fn_callable,
77
- skip_names=[context_kwarg] if context_kwarg is not None else [],
78
- )
79
- try:
80
- parameters = func_arg_metadata.arg_model.model_json_schema()
81
- except Exception as e:
82
- raise TypeError(
83
- f'Unable to parse parameters for function "{fn.__name__}": {e}'
84
- ) from e
89
+ context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context)
90
+ if context_kwarg:
91
+ schema = prune_params(schema, params=[context_kwarg])
85
92
 
86
93
  return cls(
87
- fn=fn_callable,
94
+ fn=fn,
88
95
  name=func_name,
89
96
  description=func_doc,
90
- parameters=parameters,
91
- fn_metadata=func_arg_metadata,
92
- is_async=is_async,
97
+ parameters=schema,
93
98
  context_kwarg=context_kwarg,
94
99
  tags=tags or set(),
100
+ annotations=annotations,
101
+ serializer=serializer,
95
102
  )
96
103
 
97
104
  async def run(
@@ -101,18 +108,31 @@ class Tool(BaseModel):
101
108
  ) -> list[TextContent | ImageContent | EmbeddedResource]:
102
109
  """Run the tool with arguments."""
103
110
  try:
104
- pass_args = (
105
- {self.context_kwarg: context}
106
- if self.context_kwarg is not None
107
- else None
111
+ injected_args = (
112
+ {self.context_kwarg: context} if self.context_kwarg is not None else {}
108
113
  )
109
- result = await self.fn_metadata.call_fn_with_arg_validation(
110
- fn=self.fn,
111
- fn_is_async=self.is_async,
112
- arguments_to_validate=arguments,
113
- arguments_to_pass_directly=pass_args,
114
- )
115
- return _convert_to_content(result)
114
+
115
+ parsed_args = arguments.copy()
116
+
117
+ # Pre-parse data from JSON in order to handle cases like `["a", "b", "c"]`
118
+ # being passed in as JSON inside a string rather than an actual list.
119
+ #
120
+ # Claude desktop is prone to this - in fact it seems incapable of NOT doing
121
+ # this. For sub-models, it tends to pass dicts (JSON objects) as JSON strings,
122
+ # which can be pre-parsed here.
123
+ for param_name in self.parameters["properties"]:
124
+ if isinstance(parsed_args.get(param_name, None), str):
125
+ try:
126
+ parsed_args[param_name] = json.loads(parsed_args[param_name])
127
+ except json.JSONDecodeError:
128
+ pass
129
+
130
+ type_adapter = get_cached_typeadapter(self.fn)
131
+ result = type_adapter.validate_python(parsed_args | injected_args)
132
+ if inspect.isawaitable(result):
133
+ result = await result
134
+
135
+ return _convert_to_content(result, serializer=self.serializer)
116
136
  except Exception as e:
117
137
  raise ToolError(f"Error executing tool {self.name}: {e}") from e
118
138
 
@@ -121,6 +141,7 @@ class Tool(BaseModel):
121
141
  "name": self.name,
122
142
  "description": self.description,
123
143
  "inputSchema": self.parameters,
144
+ "annotations": self.annotations,
124
145
  }
125
146
  return MCPTool(**kwargs | overrides)
126
147
 
@@ -132,6 +153,7 @@ class Tool(BaseModel):
132
153
 
133
154
  def _convert_to_content(
134
155
  result: Any,
156
+ serializer: Callable[[Any], str] | None = None,
135
157
  _process_as_single_item: bool = False,
136
158
  ) -> list[TextContent | ImageContent | EmbeddedResource]:
137
159
  """Convert a result to a sequence of content objects."""
@@ -166,23 +188,18 @@ def _convert_to_content(
166
188
 
167
189
  return other_content + mcp_types
168
190
 
169
- # if the result is a bytes object, convert it to a text content object
170
191
  if not isinstance(result, str):
171
- try:
172
- jsonable_result = pydantic_core.to_jsonable_python(result)
173
- if jsonable_result is None:
174
- return [TextContent(type="text", text="null")]
175
- elif isinstance(jsonable_result, bool):
176
- return [
177
- TextContent(
178
- type="text", text="true" if jsonable_result else "false"
179
- )
180
- ]
181
- elif isinstance(jsonable_result, str | int | float):
182
- return [TextContent(type="text", text=str(jsonable_result))]
183
- else:
184
- return [TextContent(type="text", text=json.dumps(jsonable_result))]
185
- except Exception:
186
- result = str(result)
192
+ if serializer is None:
193
+ result = default_serializer(result)
194
+ else:
195
+ try:
196
+ result = serializer(result)
197
+ except Exception as e:
198
+ logger.warning(
199
+ "Error serializing tool result: %s",
200
+ e,
201
+ exc_info=True,
202
+ )
203
+ result = default_serializer(result)
187
204
 
188
205
  return [TextContent(type="text", text=result)]
@@ -4,7 +4,7 @@ from collections.abc import Callable
4
4
  from typing import TYPE_CHECKING, Any
5
5
 
6
6
  from mcp.shared.context import LifespanContextT
7
- from mcp.types import EmbeddedResource, ImageContent, TextContent
7
+ from mcp.types import EmbeddedResource, ImageContent, TextContent, ToolAnnotations
8
8
 
9
9
  from fastmcp.exceptions import NotFoundError
10
10
  from fastmcp.settings import DuplicateBehavior
@@ -22,8 +22,13 @@ logger = get_logger(__name__)
22
22
  class ToolManager:
23
23
  """Manages FastMCP tools."""
24
24
 
25
- def __init__(self, duplicate_behavior: DuplicateBehavior | None = None):
25
+ def __init__(
26
+ self,
27
+ duplicate_behavior: DuplicateBehavior | None = None,
28
+ serializer: Callable[[Any], str] | None = None,
29
+ ):
26
30
  self._tools: dict[str, Tool] = {}
31
+ self._serializer = serializer
27
32
 
28
33
  # Default to "warn" if None is provided
29
34
  if duplicate_behavior is None:
@@ -61,9 +66,17 @@ class ToolManager:
61
66
  name: str | None = None,
62
67
  description: str | None = None,
63
68
  tags: set[str] | None = None,
69
+ annotations: ToolAnnotations | None = None,
64
70
  ) -> Tool:
65
71
  """Add a tool to the server."""
66
- tool = Tool.from_function(fn, name=name, description=description, tags=tags)
72
+ tool = Tool.from_function(
73
+ fn,
74
+ name=name,
75
+ description=description,
76
+ tags=tags,
77
+ annotations=annotations,
78
+ serializer=self._serializer,
79
+ )
67
80
  return self.add_tool(tool)
68
81
 
69
82
  def add_tool(self, tool: Tool, key: str | None = None) -> Tool:
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import (
4
+ asynccontextmanager,
5
+ )
6
+ from contextvars import ContextVar
7
+
8
+ from starlette.requests import Request
9
+
10
+ from fastmcp.utilities.logging import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
+ _current_starlette_request: ContextVar[Request | None] = ContextVar(
16
+ "starlette_request",
17
+ default=None,
18
+ )
19
+
20
+
21
+ @asynccontextmanager
22
+ async def starlette_request_context(request: Request):
23
+ token = _current_starlette_request.set(request)
24
+ try:
25
+ yield
26
+ finally:
27
+ _current_starlette_request.reset(token)
28
+
29
+
30
+ def get_current_starlette_request() -> Request | None:
31
+ return _current_starlette_request.get()
32
+
33
+
34
+ class RequestMiddleware:
35
+ """
36
+ Middleware that stores each request in a ContextVar
37
+ """
38
+
39
+ def __init__(self, app):
40
+ self.app = app
41
+
42
+ async def __call__(self, scope, receive, send):
43
+ async with starlette_request_context(Request(scope)):
44
+ await self.app(scope, receive, send)
fastmcp/utilities/json_schema.py ADDED
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from collections.abc import Mapping, Sequence
5
+
6
+
7
+ def _prune_param(schema: dict, param: str) -> dict:
8
+ """Return a new schema with *param* removed from `properties`, `required`,
9
+ and (if no longer referenced) `$defs`.
10
+ """
11
+
12
+ # ── 1. drop from properties/required ──────────────────────────────
13
+ props = schema.get("properties", {})
14
+ removed = props.pop(param, None)
15
+ if removed is None: # nothing to do
16
+ return schema
17
+ # Keep empty properties object rather than removing it entirely
18
+ schema["properties"] = props
19
+ if param in schema.get("required", []):
20
+ schema["required"].remove(param)
21
+ if not schema["required"]:
22
+ schema.pop("required")
23
+
24
+ # ── 2. collect all remaining local $ref targets ───────────────────
25
+ used_defs: set[str] = set()
26
+
27
+ def walk(node: object) -> None: # depth-first traversal
28
+ if isinstance(node, Mapping):
29
+ ref = node.get("$ref")
30
+ if isinstance(ref, str) and ref.startswith("#/$defs/"):
31
+ used_defs.add(ref.split("/")[-1])
32
+ for v in node.values():
33
+ walk(v)
34
+ elif isinstance(node, Sequence) and not isinstance(node, str | bytes):
35
+ for v in node:
36
+ walk(v)
37
+
38
+ walk(schema)
39
+
40
+ # ── 3. remove orphaned definitions ────────────────────────────────
41
+ defs = schema.get("$defs", {})
42
+ for def_name in list(defs):
43
+ if def_name not in used_defs:
44
+ defs.pop(def_name)
45
+ if not defs:
46
+ schema.pop("$defs", None)
47
+
48
+ return schema
49
+
50
+
51
+ def prune_params(schema: dict, params: list[str]) -> dict:
52
+ """
53
+ Remove the given parameters from the schema.
54
+
55
+ """
56
+ schema = copy.deepcopy(schema)
57
+ for param in params:
58
+ schema = _prune_param(schema, param=param)
59
+ return schema
@@ -1001,53 +1001,153 @@ def format_description_with_responses(
1001
1001
  responses: dict[
1002
1002
  str, Any
1003
1003
  ], # Changed from specific ResponseInfo type to avoid circular imports
1004
+ parameters: list[openapi.ParameterInfo] | None = None, # Add parameters parameter
1005
+ request_body: openapi.RequestBodyInfo | None = None, # Add request_body parameter
1004
1006
  ) -> str:
1005
- """Formats the base description string with response information."""
1006
- if not responses:
1007
- return base_description
1007
+ """
1008
+ Formats the base description string with response, parameter, and request body information.
1009
+
1010
+ Args:
1011
+ base_description (str): The initial description to be formatted.
1012
+ responses (dict[str, Any]): A dictionary of response information, keyed by status code.
1013
+ parameters (list[openapi.ParameterInfo] | None, optional): A list of parameter information,
1014
+ including path and query parameters. Each parameter includes details such as name,
1015
+ location, whether it is required, and a description.
1016
+ request_body (openapi.RequestBodyInfo | None, optional): Information about the request body,
1017
+ including its description, whether it is required, and its content schema.
1008
1018
 
1019
+ Returns:
1020
+ str: The formatted description string with additional details about responses, parameters,
1021
+ and the request body.
1022
+ """
1009
1023
  desc_parts = [base_description]
1010
- response_section = "\n\n**Responses:**"
1011
- added_response_section = False
1012
1024
 
1013
- # Determine success codes (common ones)
1014
- success_codes = {"200", "201", "202", "204"} # As strings
1015
- success_status = next((s for s in success_codes if s in responses), None)
1025
+ # Add parameter information
1026
+ if parameters:
1027
+ # Process path parameters
1028
+ path_params = [p for p in parameters if p.location == "path"]
1029
+ if path_params:
1030
+ param_section = "\n\n**Path Parameters:**"
1031
+ desc_parts.append(param_section)
1032
+ for param in path_params:
1033
+ required_marker = " (Required)" if param.required else ""
1034
+ param_desc = f"\n- **{param.name}**{required_marker}: {param.description or 'No description.'}"
1035
+ desc_parts.append(param_desc)
1036
+
1037
+ # Process query parameters
1038
+ query_params = [p for p in parameters if p.location == "query"]
1039
+ if query_params:
1040
+ param_section = "\n\n**Query Parameters:**"
1041
+ desc_parts.append(param_section)
1042
+ for param in query_params:
1043
+ required_marker = " (Required)" if param.required else ""
1044
+ param_desc = f"\n- **{param.name}**{required_marker}: {param.description or 'No description.'}"
1045
+ desc_parts.append(param_desc)
1046
+
1047
+ # Add request body information if present
1048
+ if request_body and request_body.description:
1049
+ req_body_section = "\n\n**Request Body:**"
1050
+ desc_parts.append(req_body_section)
1051
+ required_marker = " (Required)" if request_body.required else ""
1052
+ desc_parts.append(f"\n{request_body.description}{required_marker}")
1053
+
1054
+ # Add request body property descriptions if available
1055
+ if request_body.content_schema:
1056
+ media_type = (
1057
+ "application/json"
1058
+ if "application/json" in request_body.content_schema
1059
+ else next(iter(request_body.content_schema), None)
1060
+ )
1061
+ if media_type:
1062
+ schema = request_body.content_schema.get(media_type, {})
1063
+ if isinstance(schema, dict) and "properties" in schema:
1064
+ desc_parts.append("\n\n**Request Properties:**")
1065
+ for prop_name, prop_schema in schema["properties"].items():
1066
+ if (
1067
+ isinstance(prop_schema, dict)
1068
+ and "description" in prop_schema
1069
+ ):
1070
+ required = prop_name in schema.get("required", [])
1071
+ req_mark = " (Required)" if required else ""
1072
+ desc_parts.append(
1073
+ f"\n- **{prop_name}**{req_mark}: {prop_schema['description']}"
1074
+ )
1016
1075
 
1017
- # Process all responses
1018
- responses_to_process = responses.items()
1076
+ # Add response information
1077
+ if responses:
1078
+ response_section = "\n\n**Responses:**"
1079
+ added_response_section = False
1019
1080
 
1020
- for status_code, resp_info in sorted(responses_to_process):
1021
- if not added_response_section:
1022
- desc_parts.append(response_section)
1023
- added_response_section = True
1081
+ # Determine success codes (common ones)
1082
+ success_codes = {"200", "201", "202", "204"} # As strings
1083
+ success_status = next((s for s in success_codes if s in responses), None)
1024
1084
 
1025
- status_marker = " (Success)" if status_code == success_status else ""
1026
- desc_parts.append(
1027
- f"\n- **{status_code}**{status_marker}: {resp_info.description or 'No description.'}"
1028
- )
1085
+ # Process all responses
1086
+ responses_to_process = responses.items()
1029
1087
 
1030
- # Process content schemas for this response
1031
- if resp_info.content_schema:
1032
- # Prioritize json, then take first available
1033
- media_type = (
1034
- "application/json"
1035
- if "application/json" in resp_info.content_schema
1036
- else next(iter(resp_info.content_schema), None)
1088
+ for status_code, resp_info in sorted(responses_to_process):
1089
+ if not added_response_section:
1090
+ desc_parts.append(response_section)
1091
+ added_response_section = True
1092
+
1093
+ status_marker = " (Success)" if status_code == success_status else ""
1094
+ desc_parts.append(
1095
+ f"\n- **{status_code}**{status_marker}: {resp_info.description or 'No description.'}"
1037
1096
  )
1038
1097
 
1039
- if media_type:
1040
- schema = resp_info.content_schema.get(media_type)
1041
- desc_parts.append(f" - Content-Type: `{media_type}`")
1098
+ # Process content schemas for this response
1099
+ if resp_info.content_schema:
1100
+ # Prioritize json, then take first available
1101
+ media_type = (
1102
+ "application/json"
1103
+ if "application/json" in resp_info.content_schema
1104
+ else next(iter(resp_info.content_schema), None)
1105
+ )
1106
+
1107
+ if media_type:
1108
+ schema = resp_info.content_schema.get(media_type)
1109
+ desc_parts.append(f" - Content-Type: `{media_type}`")
1110
+
1111
+ # Add response property descriptions
1112
+ if isinstance(schema, dict):
1113
+ # Handle array responses
1114
+ if schema.get("type") == "array" and "items" in schema:
1115
+ items_schema = schema["items"]
1116
+ if (
1117
+ isinstance(items_schema, dict)
1118
+ and "properties" in items_schema
1119
+ ):
1120
+ desc_parts.append("\n - **Response Item Properties:**")
1121
+ for prop_name, prop_schema in items_schema[
1122
+ "properties"
1123
+ ].items():
1124
+ if (
1125
+ isinstance(prop_schema, dict)
1126
+ and "description" in prop_schema
1127
+ ):
1128
+ desc_parts.append(
1129
+ f"\n - **{prop_name}**: {prop_schema['description']}"
1130
+ )
1131
+ # Handle object responses
1132
+ elif "properties" in schema:
1133
+ desc_parts.append("\n - **Response Properties:**")
1134
+ for prop_name, prop_schema in schema["properties"].items():
1135
+ if (
1136
+ isinstance(prop_schema, dict)
1137
+ and "description" in prop_schema
1138
+ ):
1139
+ desc_parts.append(
1140
+ f"\n - **{prop_name}**: {prop_schema['description']}"
1141
+ )
1042
1142
 
1043
- if schema:
1044
1143
  # Generate Example
1045
- example = generate_example_from_schema(schema)
1046
- if example != "unknown_type" and example is not None:
1047
- desc_parts.append("\n - **Example:**")
1048
- desc_parts.append(
1049
- format_json_for_description(example, indent=2)
1050
- )
1144
+ if schema:
1145
+ example = generate_example_from_schema(schema)
1146
+ if example != "unknown_type" and example is not None:
1147
+ desc_parts.append("\n - **Example:**")
1148
+ desc_parts.append(
1149
+ format_json_for_description(example, indent=2)
1150
+ )
1051
1151
 
1052
1152
  return "\n".join(desc_parts)
1053
1153
 
@@ -1069,7 +1169,15 @@ def _combine_schemas(route: openapi.HTTPRoute) -> dict[str, Any]:
1069
1169
  for param in route.parameters:
1070
1170
  if param.required:
1071
1171
  required.append(param.name)
1072
- properties[param.name] = param.schema_
1172
+
1173
+ # Copy the schema and add description if available
1174
+ param_schema = param.schema_.copy() if isinstance(param.schema_, dict) else {}
1175
+
1176
+ # Add parameter description to schema if available and not already present
1177
+ if param.description and not param_schema.get("description"):
1178
+ param_schema["description"] = param.description
1179
+
1180
+ properties[param.name] = param_schema
1073
1181
 
1074
1182
  # Add request body if it exists
1075
1183
  if route.request_body and route.request_body.content_schema:
@@ -1077,8 +1185,11 @@ def _combine_schemas(route: openapi.HTTPRoute) -> dict[str, Any]:
1077
1185
  content_type = next(iter(route.request_body.content_schema))
1078
1186
  body_schema = route.request_body.content_schema[content_type]
1079
1187
  body_props = body_schema.get("properties", {})
1188
+
1189
+ # Add request body properties
1080
1190
  for prop_name, prop_schema in body_props.items():
1081
1191
  properties[prop_name] = prop_schema
1192
+
1082
1193
  if route.request_body.required:
1083
1194
  required.extend(body_schema.get("required", []))
1084
1195
 
fastmcp/utilities/types.py CHANGED
@@ -1,14 +1,79 @@
1
1
  """Common types used across FastMCP."""
2
2
 
3
3
  import base64
4
+ import inspect
5
+ from collections.abc import Callable
6
+ from functools import lru_cache
4
7
  from pathlib import Path
5
- from typing import TypeVar
8
+ from types import UnionType
9
+ from typing import Annotated, TypeVar, Union, get_args, get_origin
6
10
 
7
11
  from mcp.types import ImageContent
12
+ from pydantic import TypeAdapter
8
13
 
9
14
  T = TypeVar("T")
10
15
 
11
16
 
17
+ @lru_cache(maxsize=5000)
18
+ def get_cached_typeadapter(cls: T) -> TypeAdapter[T]:
19
+ """
20
+ TypeAdapters are heavy objects, and in an application context we'd typically
21
+ create them once in a global scope and reuse them as often as possible.
22
+ However, this isn't feasible for user-generated functions. Instead, we use a
23
+ cache to minimize the cost of creating them as much as possible.
24
+ """
25
+ return TypeAdapter(cls)
26
+
27
+
28
+ def issubclass_safe(cls: type, base: type) -> bool:
29
+ """Check if cls is a subclass of base, even if cls is a type variable."""
30
+ try:
31
+ if origin := get_origin(cls):
32
+ return issubclass_safe(origin, base)
33
+ return issubclass(cls, base)
34
+ except TypeError:
35
+ return False
36
+
37
+
38
+ def is_class_member_of_type(cls: type, base: type) -> bool:
39
+ """
40
+ Check if cls is a member of base, even if cls is a type variable.
41
+
42
+ Base can be a type, a UnionType, or an Annotated type. Generic types are not
43
+ considered members (e.g. T is not a member of list[T]).
44
+ """
45
+ origin = get_origin(cls)
46
+ # Handle both types of unions: UnionType (from types module, used with | syntax)
47
+ # and typing.Union (used with Union[] syntax)
48
+ if origin is UnionType or origin == Union:
49
+ return any(is_class_member_of_type(arg, base) for arg in get_args(cls))
50
+ elif origin is Annotated:
51
+ # For Annotated[T, ...], check if T is a member of base
52
+ args = get_args(cls)
53
+ if args:
54
+ return is_class_member_of_type(args[0], base)
55
+ return False
56
+ else:
57
+ return issubclass_safe(cls, base)
58
+
59
+
60
+ def find_kwarg_by_type(fn: Callable, kwarg_type: type) -> str | None:
61
+ """
62
+ Find the name of the kwarg that is of type kwarg_type.
63
+
64
+ Includes union types that contain the kwarg_type, as well as Annotated types.
65
+ """
66
+ if inspect.ismethod(fn) and hasattr(fn, "__func__"):
67
+ sig = inspect.signature(fn.__func__)
68
+ else:
69
+ sig = inspect.signature(fn)
70
+
71
+ for name, param in sig.parameters.items():
72
+ if is_class_member_of_type(param.annotation, kwarg_type):
73
+ return name
74
+ return None
75
+
76
+
12
77
  def _convert_set_defaults(maybe_set: set[T] | list[T] | None) -> set[T]:
13
78
  """Convert a set or list to a set, defaulting to an empty set if None."""
14
79
  if maybe_set is None: