chatlas 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chatlas/_provider.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations

 from abc import ABC, abstractmethod
 from typing import (
-    Any,
     AsyncIterable,
     Generic,
     Iterable,
@@ -17,6 +16,7 @@ from pydantic import BaseModel
 from ._content import Content
 from ._tools import Tool
 from ._turn import Turn
+from ._typing_extensions import TypedDict

 ChatCompletionT = TypeVar("ChatCompletionT")
 ChatCompletionChunkT = TypeVar("ChatCompletionChunkT")
@@ -24,8 +24,52 @@ ChatCompletionChunkT = TypeVar("ChatCompletionChunkT")
 ChatCompletionDictT = TypeVar("ChatCompletionDictT")


+class AnyTypeDict(TypedDict, total=False):
+    pass
+
+
+SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
+"""
+A TypedDict representing the provider-specific arguments that can be specified
+when submitting input to a model provider.
+"""
+
+
+class StandardModelParams(TypedDict, total=False):
+    """
+    A TypedDict representing the standard model parameters that can be set
+    when using a [](`~chatlas.Chat`) instance.
+    """
+
+    temperature: float
+    top_p: float
+    top_k: int
+    frequency_penalty: float
+    presence_penalty: float
+    seed: int
+    max_tokens: int
+    log_probs: bool
+    stop_sequences: list[str]
+
+
+StandardModelParamNames = Literal[
+    "temperature",
+    "top_p",
+    "top_k",
+    "frequency_penalty",
+    "presence_penalty",
+    "seed",
+    "max_tokens",
+    "log_probs",
+    "stop_sequences",
+]
+
+
 class Provider(
-    ABC, Generic[ChatCompletionT, ChatCompletionChunkT, ChatCompletionDictT]
+    ABC,
+    Generic[
+        ChatCompletionT, ChatCompletionChunkT, ChatCompletionDictT, SubmitInputArgsT
+    ],
 ):
     """
     A model provider interface for a [](`~chatlas.Chat`).
@@ -40,6 +84,24 @@ class Provider(
     directly.
     """

+    def __init__(self, *, name: str, model: str):
+        self._name = name
+        self._model = model
+
+    @property
+    def name(self):
+        """
+        Get the name of the provider
+        """
+        return self._name
+
+    @property
+    def model(self):
+        """
+        Get the model used by the provider
+        """
+        return self._model
+
     @overload
     @abstractmethod
     def chat_perform(
@@ -49,7 +111,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> ChatCompletionT: ...

     @overload
@@ -61,7 +123,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> Iterable[ChatCompletionChunkT]: ...

     @abstractmethod
@@ -72,7 +134,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...

     @overload
@@ -84,7 +146,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> ChatCompletionT: ...

     @overload
@@ -96,7 +158,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> AsyncIterable[ChatCompletionChunkT]: ...

     @abstractmethod
@@ -107,7 +169,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...

     @abstractmethod
@@ -149,3 +211,11 @@ class Provider(
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
     ) -> int: ...
+
+    @abstractmethod
+    def translate_model_params(
+        self, params: StandardModelParams
+    ) -> SubmitInputArgsT: ...
+
+    @abstractmethod
+    def supported_model_params(self) -> set[StandardModelParamNames]: ...
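To make the new contract concrete, here is a minimal sketch (not part of the diff) of how a hypothetical backend might implement the two new abstract methods. `MyArgs` and `MyBackendParams` are illustrative names, and the backend is assumed to spell `max_tokens` as `max_output_tokens`:

```python
# Illustrative sketch only: a hypothetical provider satisfying the
# translate_model_params()/supported_model_params() contract added in 0.9.0.
from typing import TypedDict

from chatlas._provider import StandardModelParamNames, StandardModelParams


class MyArgs(TypedDict, total=False):
    # Hypothetical provider-native argument names.
    temperature: float
    max_output_tokens: int


class MyBackendParams:
    def translate_model_params(self, params: StandardModelParams) -> MyArgs:
        # Map chatlas' standard parameter names onto the backend's native names.
        res: MyArgs = {}
        if "temperature" in params:
            res["temperature"] = params["temperature"]
        if "max_tokens" in params:
            res["max_output_tokens"] = params["max_tokens"]
        return res

    def supported_model_params(self) -> set[StandardModelParamNames]:
        # Advertise only the standard parameters this backend can honor.
        return {"temperature", "max_tokens"}
```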
chatlas/_snowflake.py CHANGED
@@ -20,10 +20,10 @@ from ._content import (
     ContentToolResult,
 )
 from ._logging import log_model_default
-from ._provider import Provider
+from ._provider import Provider, StandardModelParamNames, StandardModelParams
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn
 from ._utils import drop_none

 if TYPE_CHECKING:
@@ -61,7 +61,6 @@ def ChatSnowflake(
     *,
     system_prompt: Optional[str] = None,
     model: Optional[str] = None,
-    turns: Optional[list[Turn]] = None,
     connection_name: Optional[str] = None,
     account: Optional[str] = None,
     user: Optional[str] = None,
@@ -111,13 +110,6 @@ def ChatSnowflake(
         The model to use for the chat. The default, None, will pick a reasonable
         default, and warn you about it. We strongly recommend explicitly
         choosing a model for all but the most casual use.
-    turns
-        A list of turns to start the chat with (i.e., continuing a previous
-        conversation). If not provided, the conversation begins from scratch. Do
-        not provide non-None values for both `turns` and `system_prompt`. Each
-        message in the list should be a dictionary with at least `role` (usually
-        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
-        there is also a `content` field, which is a string.
     connection_name
         The name of the connection (i.e., section) within the connections.toml file.
         This is useful if you want to keep your credentials in a connections.toml file
@@ -157,14 +149,13 @@ def ChatSnowflake(
             private_key_file_pwd=private_key_file_pwd,
             kwargs=kwargs,
         ),
-        turns=normalize_turns(
-            turns or [],
-            system_prompt,
-        ),
+        system_prompt=system_prompt,
     )


-class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
+class SnowflakeProvider(
+    Provider["Completion", "CompletionChunk", "CompletionChunk", "CompleteRequest"]
+):
     def __init__(
         self,
         *,
@@ -175,6 +166,7 @@ class SnowflakeProvider(
         password: Optional[str],
         private_key_file: Optional[str],
         private_key_file_pwd: Optional[str],
+        name: str = "Snowflake",
         kwargs: Optional[dict[str, "str | int"]],
     ):
         try:
@@ -185,6 +177,7 @@ class SnowflakeProvider(
                 "`ChatSnowflake()` requires the `snowflake-ml-python` package. "
                 "Please install it via `pip install snowflake-ml-python`."
             )
+        super().__init__(name=name, model=model)

         configs: dict[str, str | int] = drop_none(
             {
@@ -198,8 +191,6 @@ class SnowflakeProvider(
             }
         )

-        self._model = model
-
         session = Session.builder.configs(configs).create()
         self._cortex_service = Root(session).cortex_inference_service

@@ -314,7 +305,7 @@ class SnowflakeProvider(
         from snowflake.core.cortex.inference_service import CompleteRequest

         req = CompleteRequest(
-            model=self._model,
+            model=self.model,
             messages=self._as_request_messages(turns),
             stream=stream,
         )
@@ -599,6 +590,26 @@ class SnowflakeProvider(

         return models.Tool(tool_spec=spec)

+    def translate_model_params(self, params: StandardModelParams) -> "CompleteRequest":
+        res: "CompleteRequest" = {}
+        if "temperature" in params:
+            res["temperature"] = params["temperature"]
+
+        if "top_p" in params:
+            res["top_p"] = params["top_p"]
+
+        if "max_tokens" in params:
+            res["max_tokens"] = params["max_tokens"]
+
+        return res
+
+    def supported_model_params(self) -> set[StandardModelParamNames]:
+        return {
+            "temperature",
+            "top_p",
+            "max_tokens",
+        }
+

 # Yield parsed event data from the Snowflake SSEClient
 # (this is only needed for the streaming case).
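A brief usage sketch of the updated constructor follows; note that `turns` is no longer accepted, so a conversation starts from `system_prompt` alone. The model and connection names here are placeholders, not recommendations:

```python
# Hypothetical usage sketch for ChatSnowflake() in 0.9.0.
from chatlas import ChatSnowflake

chat = ChatSnowflake(
    model="claude-3-5-sonnet",  # placeholder; pick a model your account supports
    system_prompt="You are a terse assistant.",
    connection_name="default",  # section name within connections.toml
)
```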
chatlas/_tokens.py CHANGED
@@ -1,9 +1,13 @@
 from __future__ import annotations

 import copy
+import importlib.resources as resources
+import warnings
 from threading import Lock
 from typing import TYPE_CHECKING

+import orjson
+
 from ._logging import logger
 from ._typing_extensions import TypedDict

@@ -17,8 +21,10 @@ class TokenUsage(TypedDict):
     """

     name: str
+    model: str
     input: int
     output: int
+    cost: float | None


 class ThreadSafeTokenCounter:
@@ -26,7 +32,9 @@ class ThreadSafeTokenCounter:
         self._lock = Lock()
         self._tokens: dict[str, TokenUsage] = {}

-    def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None:
+    def log_tokens(
+        self, name: str, model: str, input_tokens: int, output_tokens: int
+    ) -> None:
         logger.info(
             f"Provider '{name}' generated a response of {output_tokens} tokens "
             f"from an input of {input_tokens} tokens."
@@ -36,12 +44,21 @@ class ThreadSafeTokenCounter:
         if name not in self._tokens:
             self._tokens[name] = {
                 "name": name,
+                "model": model,
                 "input": input_tokens,
                 "output": output_tokens,
+                "cost": compute_price(name, model, input_tokens, output_tokens),
             }
         else:
             self._tokens[name]["input"] += input_tokens
             self._tokens[name]["output"] += output_tokens
+            price = compute_price(name, model, input_tokens, output_tokens)
+            if price is not None:
+                cost = self._tokens[name]["cost"]
+                if cost is None:
+                    self._tokens[name]["cost"] = price
+                else:
+                    self._tokens[name]["cost"] = cost + price

     def get_usage(self) -> list[TokenUsage] | None:
         with self._lock:
@@ -59,8 +76,7 @@ def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None:
     """
     Log token usage for a provider in a thread-safe manner.
    """
-    name = provider.__class__.__name__.replace("Provider", "")
-    _token_counter.log_tokens(name, tokens[0], tokens[1])
+    _token_counter.log_tokens(provider.name, provider.model, tokens[0], tokens[1])


 def tokens_reset() -> None:
@@ -71,17 +87,89 @@ def tokens_reset() -> None:
     _token_counter = ThreadSafeTokenCounter()


+class TokenPrice(TypedDict):
+    """
+    Defines the necessary information to look up pricing for a given turn.
+    """
+
+    provider: str
+    """The provider name (e.g., "OpenAI", "Anthropic", etc.)"""
+    model: str
+    """The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.)"""
+    cached_input: float
+    """The cost per user token in USD per million tokens for cached input"""
+    input: float
+    """The cost per user token in USD per million tokens"""
+    output: float
+    """The cost per assistant token in USD per million tokens"""
+
+
+# Load in pricing pulled from ellmer
+f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8")
+pricing_list: list[TokenPrice] = orjson.loads(f)
+
+
+def get_token_pricing(name: str, model: str) -> TokenPrice | None:
+    """
+    Get token pricing information given a provider name and model
+
+    Note
+    ----
+    Only a subset of providers and models are currently supported.
+    The pricing information is derived from ellmer.
+
+    Returns
+    -------
+    TokenPrice | None
+    """
+    result = next(
+        (
+            item
+            for item in pricing_list
+            if item["provider"] == name and item["model"] == model
+        ),
+        None,
+    )
+    if result is None:
+        warnings.warn(
+            f"Token pricing for the provider '{name}' and model '{model}' you selected is not available. "
+            "Please check the provider's documentation."
+        )
+
+    return result
+
+
+def compute_price(
+    name: str, model: str, input_tokens: int, output_tokens: int
+) -> float | None:
+    """
+    Compute the cost of a turn.
+
+    Returns
+    -------
+    float | None
+        The cost of the turn in USD, or None if the cost could not be calculated.
+    """
+    price = get_token_pricing(name, model)
+    if price is None:
+        return None
+    input_price = input_tokens * (price["input"] / 1e6)
+    output_price = output_tokens * (price["output"] / 1e6)
+    return input_price + output_price
+
+
 def token_usage() -> list[TokenUsage] | None:
     """
     Report on token usage in the current session

     Call this function to find out the cumulative number of tokens that you
-    have sent and received in the current session.
+    have sent and received in the current session. The price will be shown if known.

     Returns
     -------
     list[TokenUsage] | None
-        A list of dictionaries with the following keys: "name", "input", and "output".
+        A list of dictionaries with the following keys: "name", "input", "output", and "cost".
+        If no cost data is available for the name/model combination chosen, then "cost" will be None.
         If no tokens have been logged, then None is returned.
     """
     return _token_counter.get_usage()
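For reference, the arithmetic in `compute_price()` works out as follows (the per-million-token prices below are made up for illustration):

```python
# Worked example of the compute_price() math, assuming $3.00 (input) and
# $15.00 (output) per million tokens:
input_price = 2_000 * (3.00 / 1e6)  # 2,000 input tokens -> $0.006
output_price = 500 * (15.00 / 1e6)  # 500 output tokens  -> $0.0075
total = input_price + output_price  # turn cost          -> $0.0135
```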
chatlas/_tools.py CHANGED
@@ -2,11 +2,17 @@ from __future__ import annotations

 import inspect
 import warnings
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Optional

+import openai
 from pydantic import BaseModel, Field, create_model

 from . import _utils
+from ._content import (
+    ContentToolResult,
+    ContentToolResultImage,
+    ContentToolResultResource,
+)

 __all__ = (
     "Tool",
@@ -14,6 +20,8 @@ __all__ = (
 )

 if TYPE_CHECKING:
+    from mcp import ClientSession as MCPClientSession
+    from mcp import Tool as MCPTool
     from openai.types.chat import ChatCompletionToolParam

@@ -28,26 +36,168 @@ class Tool:
     ----------
     func
         The function to be invoked when the tool is called.
-    model
-        A Pydantic model that describes the input parameters for the function.
-        If not provided, the model will be inferred from the function's type hints.
-        The primary reason why you might want to provide a model in
-        Note that the name and docstring of the model takes precedence over the
-        name and docstring of the function.
+    name
+        The name of the tool.
+    description
+        A description of what the tool does.
+    parameters
+        A dictionary describing the input parameters and their types.
     """

     func: Callable[..., Any] | Callable[..., Awaitable[Any]]

     def __init__(
         self,
-        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
         *,
-        model: Optional[type[BaseModel]] = None,
+        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
+        name: str,
+        description: str,
+        parameters: dict[str, Any],
     ):
+        self.name = name
         self.func = func
         self._is_async = _utils.is_async_callable(func)
-        self.schema = func_to_schema(func, model)
-        self.name = self.schema["function"]["name"]
+        self.schema: "ChatCompletionToolParam" = {
+            "type": "function",
+            "function": {
+                "name": name,
+                "description": description,
+                "parameters": parameters,
+            },
+        }
+
+    @classmethod
+    def from_func(
+        cls: type["Tool"],
+        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
+        *,
+        model: Optional[type[BaseModel]] = None,
+    ) -> "Tool":
+        """
+        Create a Tool from a Python function
+
+        Parameters
+        ----------
+        func
+            The function to wrap as a tool.
+        model
+            A Pydantic model that describes the input parameters for the function.
+            If not provided, the model will be inferred from the function's type hints.
+            Note that the name and docstring of the model take precedence over the
+            name and docstring of the function.
+
+        Returns
+        -------
+        Tool
+            A new Tool instance wrapping the provided function.
+
+        Raises
+        ------
+        ValueError
+            If there is a mismatch between model fields and function parameters.
+        """
+
+        if model is None:
+            model = func_to_basemodel(func)
+
+        # Throw if there is a mismatch between the model and the function parameters
+        params = inspect.signature(func).parameters
+        fields = model.model_fields
+        diff = set(params) ^ set(fields)
+        if diff:
+            raise ValueError(
+                f"`model` fields must match tool function parameters exactly. "
+                f"Fields found in one but not the other: {diff}"
+            )
+
+        params = basemodel_to_param_schema(model)
+
+        return cls(
+            func=func,
+            name=model.__name__ or func.__name__,
+            description=model.__doc__ or func.__doc__ or "",
+            parameters=params,
+        )
+
+    @classmethod
+    def from_mcp(
+        cls: type["Tool"],
+        session: "MCPClientSession",
+        mcp_tool: "MCPTool",
+    ) -> "Tool":
+        """
+        Create a Tool from an MCP tool
+
+        Parameters
+        ----------
+        session
+            The MCP client session to use for calling the tool.
+        mcp_tool
+            The MCP tool to wrap.
+
+        Returns
+        -------
+        Tool
+            A new Tool instance wrapping the MCP tool.
+        """
+
+        async def _call(**args: Any) -> AsyncGenerator[ContentToolResult, None]:
+            result = await session.call_tool(mcp_tool.name, args)
+
+            # Raise an error if the tool call resulted in an error. It doesn't seem to
+            # be very well defined how to get at the error message, but it appears that
+            # it gets stored in the `text` attribute of the content. Also, empirically,
+            # the error message seems to include `Error executing tool {tool_name}: ...`,
+            # so we surface that text directly.
+            if result.isError:
+                err_msg = getattr(
+                    result.content[0],
+                    "text",
+                    f"Error executing tool {mcp_tool.name}.",
+                )
+                raise RuntimeError(err_msg)
+
+            for content in result.content:
+                if content.type == "text":
+                    yield ContentToolResult(value=content.text)
+                elif content.type == "image":
+                    if content.mimeType not in (
+                        "image/png",
+                        "image/jpeg",
+                        "image/webp",
+                        "image/gif",
+                    ):
+                        raise ValueError(
+                            f"Unsupported image MIME type: {content.mimeType}"
+                        )
+
+                    yield ContentToolResultImage(
+                        value=content.data,
+                        mime_type=content.mimeType,
+                    )
+                elif content.type == "resource":
+                    from mcp.types import TextResourceContents
+
+                    resource = content.resource
+                    if isinstance(resource, TextResourceContents):
+                        blob = resource.text.encode("utf-8")
+                    else:
+                        blob = resource.blob.encode("utf-8")
+
+                    yield ContentToolResultResource(
+                        value=blob, mime_type=content.resource.mimeType
+                    )
+                else:
+                    raise RuntimeError(f"Unexpected content type: {content.type}")
+
+        params = mcp_tool_input_schema_to_param_schema(mcp_tool.inputSchema)
+
+        return cls(
+            func=_utils.wrap_async(_call),
+            name=mcp_tool.name,
+            description=mcp_tool.description or "",
+            parameters=params,
+        )


 class ToolRejectError(Exception):
@@ -160,14 +310,6 @@ def func_to_basemodel(func: Callable) -> type[BaseModel]:


 def basemodel_to_param_schema(model: type[BaseModel]) -> dict[str, object]:
-    try:
-        import openai
-    except ImportError:
-        raise ImportError(
-            "The openai package is required for this functionality. "
-            "Please install it with `pip install openai`."
-        )
-
     # Lean on openai's ability to translate BaseModel.model_json_schema()
     # to a valid tool schema (this wouldn't be impossible to do ourselves,
     # but it's fair amount of logic to substitute `$refs`, etc.)
@@ -177,10 +319,27 @@ def basemodel_to_param_schema(model: type[BaseModel]) -> dict[str, object]:
     if "parameters" not in fn:
         raise ValueError("Expected `parameters` in function definition.")

-    params = fn["parameters"]
+    params = rm_param_titles(fn["parameters"])
+
+    return params
+
+
+def mcp_tool_input_schema_to_param_schema(
+    input_schema: dict[str, Any],
+) -> dict[str, object]:
+    params = rm_param_titles(input_schema)
+
+    if "additionalProperties" not in params:
+        params["additionalProperties"] = False
+
+    return params
+

-    # For some reason, openai (or pydantic?) wants to include a title
-    # at the model and field level. I don't think we actually need or want this.
+def rm_param_titles(
+    params: dict[str, object],
+) -> dict[str, object]:
+    # For some reason, pydantic wants to include a title at the model and field
+    # level. I don't think we actually need or want this.
     if "title" in params:
         del params["title"]
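Finally, a short sketch of the new construction path: `Tool(...)` is now keyword-only, and wrapping a plain function moves to the `from_func()` classmethod added in this release. Names below are illustrative, and `Tool` is assumed to be importable from the top-level `chatlas` package:

```python
# Illustrative sketch: creating a Tool via the new from_func() classmethod,
# which replaces passing a function directly to Tool(...) as in 0.8.0.
from chatlas import Tool


def add(x: int, y: int) -> int:
    """Add two numbers together."""
    return x + y


tool = Tool.from_func(add)
print(tool.name)    # name inferred from the function (or a supplied model)
print(tool.schema)  # OpenAI-style function schema built from the type hints
```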