chatlas 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chatlas might be problematic.

chatlas/_chat.py CHANGED
@@ -42,27 +42,41 @@ from ._display import (
     MockMarkdownDisplay,
 )
 from ._logging import log_tool_error
-from ._provider import Provider
+from ._mcp_manager import MCPSessionManager
+from ._provider import Provider, StandardModelParams, SubmitInputArgsT
+from ._tokens import compute_cost, get_token_pricing
 from ._tools import Tool, ToolRejectError
 from ._turn import Turn, user_turn
-from ._typing_extensions import TypedDict
-from ._utils import html_escape, wrap_async
+from ._typing_extensions import TypedDict, TypeGuard
+from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
 
 
-class AnyTypeDict(TypedDict, total=False):
-    pass
+class TokensDict(TypedDict):
+    """
+    A TypedDict representing the token counts for a turn in the chat.
+
+    `role` is the role of the turn (i.e., "user" or "assistant").
+    `tokens` is the number of new tokens used in the turn.
+    `tokens_total` is the total number of tokens used in the turn.
+    For example, if a new user input of 2 tokens is sent along with 10 tokens
+    of context from prior turns (input and output), `tokens_total` would be 12.
+    """
 
+    role: Literal["user", "assistant"]
+    tokens: int
+    tokens_total: int
 
-SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
-"""
-A TypedDict representing the arguments that can be passed to the `.chat()`
-method of a [](`~chatlas.Chat`) instance.
-"""
 
 CompletionT = TypeVar("CompletionT")
 
 EchoOptions = Literal["output", "all", "none", "text"]
 
+T = TypeVar("T")
+
+
+def is_present(value: T | None | MISSING_TYPE) -> TypeGuard[T]:
+    return value is not None and not isinstance(value, MISSING_TYPE)
+
 
 class Chat(Generic[SubmitInputArgsT, CompletionT]):
     """
@@ -82,7 +96,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
     def __init__(
         self,
         provider: Provider,
-        turns: Optional[Sequence[Turn]] = None,
+        system_prompt: Optional[str] = None,
     ):
         """
         Create a new chat object.
@@ -91,11 +105,13 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         ----------
         provider
             A [](`~chatlas.Provider`) object.
-        turns
-            A list of [](`~chatlas.Turn`) objects to initialize the chat with.
+        system_prompt
+            A system prompt to set the behavior of the assistant.
         """
         self.provider = provider
-        self._turns: list[Turn] = list(turns or [])
+        self._turns: list[Turn] = []
+        self.system_prompt = system_prompt
+
         self._tools: dict[str, Tool] = {}
         self._on_tool_request_callbacks = CallbackManager()
         self._on_tool_result_callbacks = CallbackManager()
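One consequence of this change: `Chat.__init__` no longer accepts `turns`, so prior history has to be restored through `.set_turns()`. A sketch of the migration (using `ChatOpenAI` for illustration; `previous_turns` is hypothetical):

```python
from chatlas import ChatOpenAI

# chatlas 0.8.x: Chat(provider, turns=[...])
# chatlas 0.9.x: the constructor takes a system prompt instead...
chat = ChatOpenAI(system_prompt="Be terse.")

# ...and prior (non-system) history is restored separately, e.g.:
# chat.set_turns(previous_turns)  # `previous_turns` is a hypothetical list[Turn]
```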
@@ -105,6 +121,11 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             "rich_console": {},
             "css_styles": {},
         }
+        self._mcp_manager = MCPSessionManager()
+
+        # Chat input parameters from `set_model_params()`
+        self._standard_model_params: StandardModelParams = {}
+        self._submit_input_kwargs: Optional[SubmitInputArgsT] = None
 
     def get_turns(
         self,
@@ -149,8 +170,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         """
         Set the turns of the chat.
 
-        This method is primarily useful for clearing or setting the turns of the
-        chat (i.e., limiting the context window).
+        Replaces the current chat history state (i.e., turns) with the provided turns.
+        This can be useful for:
+        * Clearing (or trimming) the chat history (i.e., `.set_turns([])`).
+        * Restoring context from a previous chat.
 
         Parameters
         ----------
@@ -165,7 +188,28 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 "Consider removing this turn and setting the `.system_prompt` separately "
                 "if you want to change the system prompt."
             )
-        self._turns = list(turns)
+
+        turns_list = list(turns)
+        # Preserve the system prompt if it exists
+        if self._turns and self._turns[0].role == "system":
+            turns_list.insert(0, self._turns[0])
+        self._turns = turns_list
+
+    def add_turn(self, turn: Turn):
+        """
+        Add a turn to the chat.
+
+        Parameters
+        ----------
+        turn
+            The turn to add. Turns with the role "system" are not allowed.
+        """
+        if turn.role == "system":
+            raise ValueError(
+                "Turns with the role 'system' are not allowed. "
+                "The system prompt must be set separately using the `.system_prompt` property."
+            )
+        self._turns.append(turn)
 
     @property
     def system_prompt(self) -> str | None:
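The reworked `.set_turns()` preserves an existing system turn, while the new `.add_turn()` rejects system turns outright. Roughly, assuming a `chat` like the one above:

```python
from chatlas import Turn

chat.set_turns([])  # clears the history; an existing system turn survives

chat.add_turn(Turn("user", "What's 1 + 1?"))
chat.add_turn(Turn("assistant", "2"))

# Raises ValueError; set the `.system_prompt` property instead:
# chat.add_turn(Turn("system", "Be terse."))
```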
@@ -188,43 +232,14 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-    @overload
-    def tokens(self) -> list[tuple[int, int] | None]: ...
-
-    @overload
-    def tokens(
-        self,
-        values: Literal["cumulative"],
-    ) -> list[tuple[int, int] | None]: ...
-
-    @overload
-    def tokens(
-        self,
-        values: Literal["discrete"],
-    ) -> list[int]: ...
-
-    def tokens(
-        self,
-        values: Literal["cumulative", "discrete"] = "discrete",
-    ) -> list[int] | list[tuple[int, int] | None]:
+    def get_tokens(self) -> list[TokensDict]:
         """
         Get the tokens for each turn in the chat.
 
-        Parameters
-        ----------
-        values
-            If "cumulative" (the default), the result can be summed to get the
-            chat's overall token usage (helpful for computing overall cost of
-            the chat). If "discrete", the result can be summed to get the number of
-            tokens the turns will cost to generate the next response (helpful
-            for estimating cost of the next response, or for determining if you
-            are about to exceed the token limit).
-
         Returns
         -------
-        list[int]
-            A list of token counts for each (non-system) turn in the chat. The
-            1st turn includes the tokens count for the system prompt (if any).
+        list[TokensDict]
+            A list of dictionaries with the token counts for each (non-system) turn.
 
         Raises
         ------
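`get_tokens()` replaces the old `tokens()` method and its "cumulative"/"discrete" modes with a single list of `TokensDict`s. A sketch of the result shape (values illustrative):

```python
tokens = chat.get_tokens()
# [
#     {"role": "user",      "tokens": 2,  "tokens_total": 2},
#     {"role": "assistant", "tokens": 10, "tokens_total": 10},
#     {"role": "user",      "tokens": 5,  "tokens_total": 17},  # 5 new + 12 context
#     {"role": "assistant", "tokens": 8,  "tokens_total": 8},
# ]

# Cumulative input-side usage, e.g. for cost accounting:
overall_input = sum(t["tokens_total"] for t in tokens if t["role"] == "user")
```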
@@ -238,9 +253,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         turns = self.get_turns(include_system_prompt=False)
 
-        if values == "cumulative":
-            return [turn.tokens for turn in turns]
-
         if len(turns) == 0:
             return []
 
@@ -276,12 +288,21 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 "Expected the 1st assistant turn to contain token counts. " + err_info
             )
 
-        res: list[int] = [
+        res: list[TokensDict] = [
             # Implied token count for the 1st user input
-            turns[1].tokens[0],
+            {
+                "role": "user",
+                "tokens": turns[1].tokens[0],
+                "tokens_total": turns[1].tokens[0],
+            },
             # The token count for the 1st assistant response
-            turns[1].tokens[1],
+            {
+                "role": "assistant",
+                "tokens": turns[1].tokens[1],
+                "tokens_total": turns[1].tokens[1],
+            },
         ]
+
         for i in range(1, len(turns) - 1, 2):
             ti = turns[i]
             tj = turns[i + 2]
@@ -296,15 +317,102 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             )
         res.extend(
             [
-                # Implied token count for the user input
-                tj.tokens[0] - sum(ti.tokens),
-                # The token count for the assistant response
-                tj.tokens[1],
+                {
+                    "role": "user",
+                    # Implied token count for the user input
+                    "tokens": tj.tokens[0] - sum(ti.tokens),
+                    # Total tokens = distinct new tokens + context sent for the turn
+                    "tokens_total": tj.tokens[0],
+                },
+                {
+                    "role": "assistant",
+                    # The token count for the assistant response
+                    "tokens": tj.tokens[1],
+                    # Total tokens = total assistant tokens used in the turn
+                    "tokens_total": tj.tokens[1],
+                },
             ]
         )
 
         return res
 
+    def get_cost(
+        self,
+        options: Literal["all", "last"] = "all",
+        token_price: Optional[tuple[float, float]] = None,
+    ) -> float:
+        """
+        Estimate the cost of the chat.
+
+        Note
+        ----
+        This is a rough estimate; treat it as such. Providers may change their
+        pricing frequently and without notice.
+
+        Parameters
+        ----------
+        options
+            One of the following (default is "all"):
+            - `"all"`: Return the total cost of all turns in the chat.
+            - `"last"`: Return the cost of the last turn in the chat.
+        token_price
+            An optional tuple in the format of (input_token_cost,
+            output_token_cost) for bringing your own cost information.
+            - `input_token_cost`: The cost per user token in USD per
+              million tokens.
+            - `output_token_cost`: The cost per assistant token in USD
+              per million tokens.
+
+        Returns
+        -------
+        float
+            The cost of the chat, in USD.
+        """
+
+        # Look up token cost for user and input tokens based on the provider and model
+        turns_tokens = self.get_tokens()
+        if token_price:
+            input_token_price = token_price[0] / 1e6
+            output_token_price = token_price[1] / 1e6
+        else:
+            price_token = get_token_pricing(self.provider.name, self.provider.model)
+            if not price_token:
+                raise KeyError(
+                    f"We could not locate pricing information for model '{self.provider.model}' from provider '{self.provider.name}'. "
+                    "If you know the pricing for this model, specify it in `token_price`."
+                )
+            input_token_price = price_token["input"] / 1e6
+            output_token_price = price_token["output"] / 1e6
+
+        if len(turns_tokens) == 0:
+            return 0.0
+
+        if options not in ("all", "last"):
+            raise ValueError(
+                f"Expected `options` to be one of 'all' or 'last', not '{options}'"
+            )
+
+        if options == "all":
+            asst_tokens = sum(
+                u["tokens_total"] for u in turns_tokens if u["role"] == "assistant"
+            )
+            user_tokens = sum(
+                u["tokens_total"] for u in turns_tokens if u["role"] == "user"
+            )
+            cost = (asst_tokens * output_token_price) + (
+                user_tokens * input_token_price
+            )
+            return cost
+
+        last_turn = turns_tokens[-1]
+        if last_turn["role"] == "assistant":
+            return last_turn["tokens"] * output_token_price
+        if last_turn["role"] == "user":
+            return last_turn["tokens_total"] * input_token_price
+        raise ValueError(
+            f"Expected last turn to have a role of 'user' or 'assistant', not '{last_turn['role']}'"
+        )
+
     def token_count(
         self,
         *args: Content | str,
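Typical `get_cost()` usage might look like the following sketch (prices are illustrative and expressed per million tokens, matching the `token_price` convention above):

```python
total_usd = chat.get_cost()               # whole conversation, built-in pricing
last_usd = chat.get_cost(options="last")  # only the most recent turn

# Bring-your-own pricing: (input USD, output USD) per million tokens
byo_usd = chat.get_cost(token_price=(2.50, 10.00))
```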
@@ -397,6 +505,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         *,
         stream: bool = True,
         port: int = 0,
+        host: str = "127.0.0.1",
         launch_browser: bool = True,
         bg_thread: Optional[bool] = None,
         echo: Optional[EchoOptions] = None,
@@ -412,6 +521,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             Whether to stream the response (i.e., have the response appear in chunks).
         port
             The port to run the app on (the default is 0, which will choose a random port).
+        host
+            The host to run the app on (the default is "127.0.0.1").
         launch_browser
             Whether to launch a browser window.
         bg_thread
@@ -479,7 +590,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         app = App(app_ui, server)
 
         def _run_app():
-            run_app(app, launch_browser=launch_browser, port=port)
+            run_app(app, launch_browser=launch_browser, port=port, host=host)
 
         # Use bg_thread by default in Jupyter and Positron
         if bg_thread is None:
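The new `host` parameter is threaded through to `run_app()`, so the browser-based chat app can bind to something other than the loopback default. For example:

```python
# Default stays local-only:
chat.app()

# Make the app reachable from other machines (only do this on a trusted network):
chat.app(host="0.0.0.0", port=8000)
```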
@@ -910,10 +1021,422 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             json = res[0]
         return json.value
 
+    def set_model_params(
+        self,
+        *,
+        temperature: float | None | MISSING_TYPE = MISSING,
+        top_p: float | None | MISSING_TYPE = MISSING,
+        top_k: int | None | MISSING_TYPE = MISSING,
+        frequency_penalty: float | None | MISSING_TYPE = MISSING,
+        presence_penalty: float | None | MISSING_TYPE = MISSING,
+        seed: int | None | MISSING_TYPE = MISSING,
+        max_tokens: int | None | MISSING_TYPE = MISSING,
+        log_probs: bool | None | MISSING_TYPE = MISSING,
+        stop_sequences: list[str] | None | MISSING_TYPE = MISSING,
+        kwargs: SubmitInputArgsT | None | MISSING_TYPE = MISSING,
+    ):
+        """
+        Set common model parameters for the chat.
+
+        A unified interface for setting model parameters that most providers
+        support, such as temperature, top_p, etc.
+
+        By default, if a parameter is not set (i.e., left as `MISSING`), the
+        provider's default value is used. To reset a parameter to its default
+        value, set it to `None`.
+
+        Parameters
+        ----------
+        temperature
+            Temperature of the sampling distribution.
+        top_p
+            The cumulative probability for token selection.
+        top_k
+            The number of highest probability vocabulary tokens to keep.
+        frequency_penalty
+            Frequency penalty for generated tokens.
+        presence_penalty
+            Presence penalty for generated tokens.
+        seed
+            Seed for the random number generator.
+        max_tokens
+            Maximum number of tokens to generate.
+        log_probs
+            Whether to include log probabilities in the output.
+        stop_sequences
+            A list of strings to stop generation on.
+        kwargs
+            Additional keyword arguments to use when submitting input to the
+            model. When calling this method repeatedly with different values,
+            only the `kwargs` from the last call are used.
+        """
+
+        params: StandardModelParams = {}
+
+        # Collect specified parameters
+        if is_present(temperature):
+            params["temperature"] = temperature
+        if is_present(top_p):
+            params["top_p"] = top_p
+        if is_present(top_k):
+            params["top_k"] = top_k
+        if is_present(frequency_penalty):
+            params["frequency_penalty"] = frequency_penalty
+        if is_present(presence_penalty):
+            params["presence_penalty"] = presence_penalty
+        if is_present(seed):
+            params["seed"] = seed
+        if is_present(max_tokens):
+            params["max_tokens"] = max_tokens
+        if is_present(log_probs):
+            params["log_probs"] = log_probs
+        if is_present(stop_sequences):
+            params["stop_sequences"] = stop_sequences
+
+        # Warn about unsupported parameters
+        supported = self.provider.supported_model_params()
+        unsupported = set(params.keys()) - set(supported)
+        if unsupported:
+            warnings.warn(
+                f"The following parameters are not supported by the provider: {unsupported}. "
+                "Please check the provider's documentation for supported parameters.",
+                UserWarning,
+            )
+            # Drop the unsupported parameters
+            for key in unsupported:
+                del params[key]
+
+        # Drop parameters that are set to None
+        discard = []
+        if temperature is None:
+            discard.append("temperature")
+        if top_p is None:
+            discard.append("top_p")
+        if top_k is None:
+            discard.append("top_k")
+        if frequency_penalty is None:
+            discard.append("frequency_penalty")
+        if presence_penalty is None:
+            discard.append("presence_penalty")
+        if seed is None:
+            discard.append("seed")
+        if max_tokens is None:
+            discard.append("max_tokens")
+        if log_probs is None:
+            discard.append("log_probs")
+        if stop_sequences is None:
+            discard.append("stop_sequences")
+
+        for key in discard:
+            if key in self._standard_model_params:
+                del self._standard_model_params[key]
+
+        # Update the standard model parameters
+        self._standard_model_params.update(params)
+
+        # Update the submit input kwargs
+        if kwargs is None:
+            self._submit_input_kwargs = None
+
+        if is_present(kwargs):
+            self._submit_input_kwargs = kwargs
+
+    async def register_mcp_tools_http_stream_async(
+        self,
+        *,
+        url: str,
+        include_tools: Sequence[str] = (),
+        exclude_tools: Sequence[str] = (),
+        name: Optional[str] = None,
+        namespace: Optional[str] = None,
+        transport_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        """
+        Register tools from an MCP server using streamable HTTP transport.
+
+        Connects to an MCP server (that communicates over a streamable HTTP
+        transport) and registers the available tools. This is useful for
+        utilizing tools provided by an MCP server running remotely (or
+        locally) over HTTP.
+
+        Pre-requisites
+        --------------
+
+        ::: {.callout-note}
+        Requires the `mcp` package to be installed. Install it with:
+
+        ```bash
+        pip install mcp
+        ```
+        :::
+
+        Parameters
+        ----------
+        url
+            URL endpoint where the streamable HTTP server is mounted (e.g.,
+            `http://localhost:8000/mcp`).
+        name
+            A unique name for the MCP server session. If not provided, the name
+            is derived from the MCP server information. This name is primarily
+            useful for cleanup purposes (i.e., to close a particular MCP
+            session).
+        include_tools
+            List of tool names to include. By default, all available tools are
+            included.
+        exclude_tools
+            List of tool names to exclude. This parameter and `include_tools`
+            are mutually exclusive.
+        namespace
+            A namespace to prepend to tool names (i.e., `namespace.tool_name`)
+            from this MCP server. This is primarily useful to avoid name
+            collisions with other tools already registered with the chat. This
+            namespace applies when tools are advertised to the LLM, so try
+            to use a meaningful name that describes the server and/or the tools
+            it provides. For example, if you have a server that provides tools
+            for mathematical operations, you might use `math` as the namespace.
+        transport_kwargs
+            Additional keyword arguments for the transport layer (i.e.,
+            `mcp.client.streamable_http.streamablehttp_client`).
+
+        Returns
+        -------
+        None
+
+        See Also
+        --------
+        * `.cleanup_mcp_tools()` : Clean up registered MCP tools.
+        * `.register_mcp_tools_stdio_async()` : Register tools from an MCP server using stdio transport.
+
+        Note
+        ----
+        Unlike the `.register_mcp_tools_stdio_async()` method, this method does
+        not launch an MCP server. Instead, it assumes an HTTP server is already
+        running at the specified URL. This is useful for connecting to an
+        existing MCP server that is already running and serving tools.
+
+        Examples
+        --------
+
+        Assuming you have a Python script `my_mcp_server.py` that implements an
+        MCP server like so:
+
+        ```python
+        from mcp.server.fastmcp import FastMCP
+
+        app = FastMCP("my_server")
+
+        @app.tool(description="Add two numbers.")
+        def add(x: int, y: int) -> int:
+            return x + y
+
+        app.run(transport="streamable-http")
+        ```
+
+        You can launch this server like so:
+
+        ```bash
+        python my_mcp_server.py
+        ```
+
+        Then, you can register this server with the chat as follows:
+
+        ```python
+        await chat.register_mcp_tools_http_stream_async(
+            url="http://localhost:8000/mcp"
+        )
+        ```
+        """
+        if isinstance(exclude_tools, str):
+            exclude_tools = [exclude_tools]
+        if isinstance(include_tools, str):
+            include_tools = [include_tools]
+
+        session_info = await self._mcp_manager.register_http_stream_tools(
+            name=name,
+            url=url,
+            include_tools=include_tools,
+            exclude_tools=exclude_tools,
+            namespace=namespace,
+            transport_kwargs=transport_kwargs or {},
+        )
+
+        overlapping_tools = set(self._tools.keys()) & set(session_info.tools)
+        if overlapping_tools:
+            await self._mcp_manager.close_sessions([session_info.name])
+            raise ValueError(
+                f"The following tools are already registered: {overlapping_tools}. "
+                "Consider providing a namespace when registering this MCP server "
+                "to avoid name collisions."
+            )
+
+        self._tools.update(session_info.tools)
+
+    async def register_mcp_tools_stdio_async(
+        self,
+        *,
+        command: str,
+        args: list[str],
+        name: Optional[str] = None,
+        include_tools: Sequence[str] = (),
+        exclude_tools: Sequence[str] = (),
+        namespace: Optional[str] = None,
+        transport_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        """
+        Register tools from an MCP server using stdio (standard input/output) transport.
+
+        Useful for launching an MCP server and registering its tools with the chat -- all
+        from the same Python process.
+
+        In more detail, this method:
+
+        1. Executes the given `command` with the provided `args`.
+           * This should start an MCP server that communicates via stdio.
+        2. Establishes a client connection to the MCP server using the `mcp` package.
+        3. Registers the available tools from the MCP server with the chat.
+        4. Tracks the session so that it (and its tools) can later be closed via
+           `.cleanup_mcp_tools()`.
+
+        Pre-requisites
+        --------------
+
+        ::: {.callout-note}
+        Requires the `mcp` package to be installed. Install it with:
+
+        ```bash
+        pip install mcp
+        ```
+        :::
+
+        Parameters
+        ----------
+        command
+            System command to execute to start the MCP server (e.g., `python`).
+        args
+            Arguments to pass to the system command (e.g., `["-m",
+            "my_mcp_server"]`).
+        name
+            A unique name for the MCP server session. If not provided, the name
+            is derived from the MCP server information. This name is primarily
+            useful for cleanup purposes (i.e., to close a particular MCP
+            session).
+        include_tools
+            List of tool names to include. By default, all available tools are
+            included.
+        exclude_tools
+            List of tool names to exclude. This parameter and `include_tools`
+            are mutually exclusive.
+        namespace
+            A namespace to prepend to tool names (i.e., `namespace.tool_name`)
+            from this MCP server. This is primarily useful to avoid name
+            collisions with other tools already registered with the chat. This
+            namespace applies when tools are advertised to the LLM, so try
+            to use a meaningful name that describes the server and/or the tools
+            it provides. For example, if you have a server that provides tools
+            for mathematical operations, you might use `math` as the namespace.
+        transport_kwargs
+            Additional keyword arguments for the stdio transport layer (i.e.,
+            `mcp.client.stdio.stdio_client`).
+
+        Returns
+        -------
+        None
+
+        See Also
+        --------
+        * `.cleanup_mcp_tools()` : Clean up registered MCP tools.
+        * `.register_mcp_tools_http_stream_async()` : Register tools from an MCP server using streamable HTTP transport.
+
+        Examples
+        --------
+
+        Assuming you have a Python script `my_mcp_server.py` that implements an
+        MCP server like so:
+
+        ```python
+        from mcp.server.fastmcp import FastMCP
+
+        app = FastMCP("my_server")
+
+        @app.tool(description="Add two numbers.")
+        def add(y: int, z: int) -> int:
+            return y + z
+
+        app.run(transport="stdio")
+        ```
+
+        You can register this server with the chat as follows:
+
+        ```python
+        from chatlas import ChatOpenAI
+
+        chat = ChatOpenAI()
+
+        await chat.register_mcp_tools_stdio_async(
+            command="python",
+            args=["-m", "my_mcp_server"],
+        )
+        ```
+        """
+        if isinstance(exclude_tools, str):
+            exclude_tools = [exclude_tools]
+        if isinstance(include_tools, str):
+            include_tools = [include_tools]
+
+        session_info = await self._mcp_manager.register_stdio_tools(
+            command=command,
+            args=args,
+            name=name,
+            include_tools=include_tools,
+            exclude_tools=exclude_tools,
+            namespace=namespace,
+            transport_kwargs=transport_kwargs or {},
+        )
+
+        overlapping_tools = set(self._tools.keys()) & set(session_info.tools)
+        if overlapping_tools:
+            await self._mcp_manager.close_sessions([session_info.name])
+            raise ValueError(
+                f"The following tools are already registered: {overlapping_tools}. "
+                "Consider providing a namespace when registering this MCP server "
+                "to avoid name collisions."
+            )
+
+        self._tools.update(session_info.tools)
+
+    async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
+        """
+        Close MCP server connections (and their corresponding tools).
+
+        This method closes the MCP client sessions and removes the tools registered
+        from the MCP servers. If `names` is provided, only the tools and sessions
+        associated with those names are cleaned up. If no names are provided,
+        all registered MCP tools and sessions are cleaned up.
+
+        Parameters
+        ----------
+        names
+            If provided, only clean up the tools and sessions associated
+            with these names. If not provided, clean up all registered MCP
+            tools and sessions.
+
+        Returns
+        -------
+        None
+        """
+        closed_sessions = await self._mcp_manager.close_sessions(names)
+
+        # Remove relevant MCP tools from the main tools registry
+        for session in closed_sessions:
+            for tool_name in session.tools:
+                if tool_name in self._tools:
+                    del self._tools[tool_name]
+
     def register_tool(
         self,
         func: Callable[..., Any] | Callable[..., Awaitable[Any]],
         *,
+        force: bool = False,
         model: Optional[type[BaseModel]] = None,
     ):
         """
@@ -930,7 +1453,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         recommended):
 
         ```python
-        from chatlas import ChatOpenAI, Tool
+        from chatlas import ChatOpenAI
 
 
         def add(a: int, b: int) -> int:
@@ -958,7 +1481,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         and also more directly document the input parameters:
 
         ```python
-        from chatlas import ChatOpenAI, Tool
+        from chatlas import ChatOpenAI
         from pydantic import BaseModel, Field
 
 
@@ -983,16 +1506,62 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         ----------
         func
             The function to be invoked when the tool is called.
+        force
+            If `True`, overwrite any existing tool with the same name. If `False`
+            (the default), raise an error if a tool with the same name already exists.
         model
             A Pydantic model that describes the input parameters for the function.
             If not provided, the model will be inferred from the function's type hints.
             The primary reason why you might want to provide a model in
             Note that the name and docstring of the model takes precedence over the
             name and docstring of the function.
+
+        Raises
+        ------
+        ValueError
+            If a tool with the same name already exists and `force` is `False`.
         """
-        tool = Tool(func, model=model)
+        tool = Tool.from_func(func, model=model)
+        if tool.name in self._tools and not force:
+            raise ValueError(
+                f"Tool with name '{tool.name}' is already registered. "
+                "Set `force=True` to overwrite it."
+            )
         self._tools[tool.name] = tool
 
+    def get_tools(self) -> list[Tool]:
+        """
+        Get the list of registered tools.
+
+        Returns
+        -------
+        list[Tool]
+            A list of `Tool` instances that are currently registered with the chat.
+        """
+        return list(self._tools.values())
+
+    def set_tools(
+        self, tools: list[Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool]
+    ):
+        """
+        Set the tools for the chat.
+
+        This replaces any previously registered tools with the provided list of
+        tools. This is for advanced usage -- typically, you would use
+        `.register_tool()` to register individual tools as needed.
+
+        Parameters
+        ----------
+        tools
+            A list of `Tool` instances (or callables) to set as the chat's tools.
+        """
+        self._tools = {}
+        for tool in tools:
+            if isinstance(tool, Tool):
+                self._tools[tool.name] = tool
+            else:
+                self.register_tool(tool)
+
     def on_tool_request(self, callback: Callable[[ContentToolRequest], None]):
         """
         Register a callback for a tool request event.
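The `force` flag and the new `get_tools()`/`set_tools()` accessors combine roughly as follows (assuming a `chat` instance):

```python
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


chat.register_tool(add)
# chat.register_tool(add)            # would raise ValueError: already registered
chat.register_tool(add, force=True)  # overwrites instead of raising

tools = chat.get_tools()  # list[Tool]
chat.set_tools(tools)     # wholesale replacement of the registry
```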
@@ -1257,22 +1826,23 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             assert turn is not None
             user_turn_result = None
 
-            results: list[ContentToolResult] = []
+            all_results: list[ContentToolResult] = []
             for x in turn.contents:
                 if isinstance(x, ContentToolRequest):
                     if echo == "output":
                         self._echo_content(f"\n\n{x}\n\n")
                     if content == "all":
                         yield x
-                    res = self._invoke_tool(x)
-                    if echo == "output":
-                        self._echo_content(f"\n\n{res}\n\n")
-                    if content == "all":
-                        yield res
-                    results.append(res)
+                    results = self._invoke_tool(x)
+                    for res in results:
+                        if echo == "output":
+                            self._echo_content(f"\n\n{res}\n\n")
+                        if content == "all":
+                            yield res
+                        all_results.append(res)
 
-            if results:
-                user_turn_result = Turn("user", results)
+            if all_results:
+                user_turn_result = Turn("user", all_results)
 
     @overload
     def _chat_impl_async(
@@ -1316,24 +1886,25 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             assert turn is not None
             user_turn_result = None
 
-            results: list[ContentToolResult] = []
+            all_results: list[ContentToolResult] = []
             for x in turn.contents:
                 if isinstance(x, ContentToolRequest):
                     if echo == "output":
                         self._echo_content(f"\n\n{x}\n\n")
                     if content == "all":
                         yield x
-                    res = await self._invoke_tool_async(x)
-                    if echo == "output":
-                        self._echo_content(f"\n\n{res}\n\n")
-                    if content == "all":
-                        yield res
-                    else:
-                        yield "\n\n"
-                    results.append(res)
-
-            if results:
-                user_turn_result = Turn("user", results)
+                    results = self._invoke_tool_async(x)
+                    async for res in results:
+                        if echo == "output":
+                            self._echo_content(f"\n\n{res}\n\n")
+                        if content == "all":
+                            yield res
+                        else:
+                            yield "\n\n"
+                        all_results.append(res)
+
+            if all_results:
+                user_turn_result = Turn("user", all_results)
 
     def _submit_turns(
         self,
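Both tool loops now iterate over multiple results per request: the `_invoke_tool()`/`_invoke_tool_async()` rewrites further down normalize a plain return into a single-item generator, so a tool may also yield several results, each echoed, yielded, and recorded individually. A hedged sketch of such a yielding tool, inferred from this diff rather than from separate documentation:

```python
def search_files(pattern: str):
    """Search a set of files, reporting each match as it is found."""
    for name in ("a.txt", "b.txt", "c.txt"):  # stand-in for a real scan
        if pattern in name:
            yield f"match: {name}"


chat.register_tool(search_files)
```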
@@ -1354,13 +1925,25 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         if echo == "all":
             emit_user_contents(user_turn, emit)
 
+        # Start collecting additional keyword args (from model parameters)
+        all_kwargs = self.provider.translate_model_params(
+            params=self._standard_model_params,
+        )
+
+        # Add any additional kwargs provided by the user
+        if self._submit_input_kwargs:
+            all_kwargs.update(self._submit_input_kwargs)
+
+        if kwargs:
+            all_kwargs.update(kwargs)
+
         if stream:
             response = self.provider.chat_perform(
                 stream=True,
                 turns=[*self._turns, user_turn],
                 tools=self._tools,
                 data_model=data_model,
-                kwargs=kwargs,
+                kwargs=all_kwargs,
             )
 
             result = None
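The merge order here means per-call arguments win: provider-translated params from `set_model_params()` form the base, then `set_model_params(kwargs=...)`, then anything passed to the submitting call itself. A sketch (the `logprobs` key is an assumed provider-specific argument):

```python
chat.set_model_params(temperature=0.2, kwargs={"logprobs": True})

# Overrides temperature for this one request only; the stored 0.2 is untouched
chat.chat("Hello", kwargs={"temperature": 0.9})
```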
@@ -1385,7 +1968,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 turns=[*self._turns, user_turn],
                 tools=self._tools,
                 data_model=data_model,
-                kwargs=kwargs,
+                kwargs=all_kwargs,
             )
 
             turn = self.provider.value_turn(
@@ -1462,54 +2045,56 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         self._turns.extend([user_turn, turn])
 
-    def _invoke_tool(self, x: ContentToolRequest) -> ContentToolResult:
-        tool_def = self._tools.get(x.name, None)
+    def _invoke_tool(self, request: ContentToolRequest):
+        tool_def = self._tools.get(request.name, None)
         func = tool_def.func if tool_def is not None else None
 
         if func is None:
-            e = RuntimeError(f"Unknown tool: {x.name}")
-            return ContentToolResult(value=None, error=e, request=x)
+            yield self._handle_tool_error_result(
+                request,
+                error=RuntimeError("Unknown tool."),
+            )
+            return
 
         # First, invoke the request callbacks. If a ToolRejectError is raised,
         # treat it like a tool failure (i.e., gracefully handle it).
         result: ContentToolResult | None = None
         try:
-            self._on_tool_request_callbacks.invoke(x)
+            self._on_tool_request_callbacks.invoke(request)
         except ToolRejectError as e:
-            result = ContentToolResult(value=None, error=e, request=x)
+            yield self._handle_tool_error_result(request, e)
+            return
 
-        # Invoke the tool (if it hasn't been rejected).
-        if result is None:
-            try:
-                if isinstance(x.arguments, dict):
-                    res = func(**x.arguments)
-                else:
-                    res = func(x.arguments)
+        try:
+            if isinstance(request.arguments, dict):
+                res = func(**request.arguments)
+            else:
+                res = func(request.arguments)
+
+            # Normalize res as a generator of results.
+            if not inspect.isgenerator(res):
+
+                def _as_generator(res):
+                    yield res
 
-                if isinstance(res, ContentToolResult):
-                    result = res
+                res = _as_generator(res)
+
+            for x in res:
+                if isinstance(x, ContentToolResult):
+                    result = x
                 else:
-                    result = ContentToolResult(value=res)
+                    result = ContentToolResult(value=x)
 
-                result.request = x
-            except Exception as e:
-                result = ContentToolResult(value=None, error=e, request=x)
+                result.request = request
 
-        # If we've captured an error, notify and log it.
-        if result.error:
-            warnings.warn(
-                f"Calling tool '{x.name}' led to an error.",
-                ToolFailureWarning,
-                stacklevel=2,
-            )
-            traceback.print_exc()
-            log_tool_error(x.name, str(x.arguments), result.error)
+                self._on_tool_result_callbacks.invoke(result)
+                yield result
 
-        self._on_tool_result_callbacks.invoke(result)
-        return result
+        except Exception as e:
+            yield self._handle_tool_error_result(request, e)
 
-    async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
-        tool_def = self._tools.get(x.name, None)
+    async def _invoke_tool_async(self, request: ContentToolRequest):
+        tool_def = self._tools.get(request.name, None)
         func = None
         if tool_def:
             if tool_def._is_async:
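Tool rejection keeps its short-circuit behavior under the new flow: an `on_tool_request` callback that raises `ToolRejectError` produces a single error result via `_handle_tool_error_result` and returns early. A sketch of a guard callback, assuming `ToolRejectError` is exported at the package top level as the `_tools` import suggests:

```python
from chatlas import ToolRejectError


def guard(request):
    if request.name == "delete_file":  # `delete_file` is a hypothetical tool
        raise ToolRejectError("File deletion is disabled in this session.")


chat.on_tool_request(guard)
```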
@@ -1518,45 +2103,59 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 func = wrap_async(tool_def.func)
 
         if func is None:
-            e = RuntimeError(f"Unknown tool: {x.name}")
-            return ContentToolResult(value=None, error=e, request=x)
+            yield self._handle_tool_error_result(
+                request,
+                error=RuntimeError("Unknown tool."),
+            )
+            return
 
         # First, invoke the request callbacks. If a ToolRejectError is raised,
         # treat it like a tool failure (i.e., gracefully handle it).
         result: ContentToolResult | None = None
         try:
-            await self._on_tool_request_callbacks.invoke_async(x)
+            await self._on_tool_request_callbacks.invoke_async(request)
         except ToolRejectError as e:
-            result = ContentToolResult(value=None, error=e, request=x)
+            yield self._handle_tool_error_result(request, e)
+            return
 
         # Invoke the tool (if it hasn't been rejected).
-        if result is None:
-            try:
-                if isinstance(x.arguments, dict):
-                    res = await func(**x.arguments)
-                else:
-                    res = await func(x.arguments)
+        try:
+            if isinstance(request.arguments, dict):
+                res = await func(**request.arguments)
+            else:
+                res = await func(request.arguments)
+
+            # Normalize res into a generator of results.
+            if not inspect.isasyncgen(res):
+
+                async def _as_async_generator(res):
+                    yield res
 
-                if isinstance(res, ContentToolResult):
-                    result = res
+                res = _as_async_generator(res)
+
+            async for x in res:
+                if isinstance(x, ContentToolResult):
+                    result = x
                 else:
-                    result = ContentToolResult(value=res)
+                    result = ContentToolResult(value=x)
 
-                result.request = x
-            except Exception as e:
-                result = ContentToolResult(value=None, error=e, request=x)
+                result.request = request
+                await self._on_tool_result_callbacks.invoke_async(result)
+                yield result
 
-        # If we've captured an error, notify and log it.
-        if result.error:
-            warnings.warn(
-                f"Calling tool '{x.name}' led to an error.",
-                ToolFailureWarning,
-                stacklevel=2,
-            )
-            traceback.print_exc()
-            log_tool_error(x.name, str(x.arguments), result.error)
+        except Exception as e:
+            yield self._handle_tool_error_result(request, e)
 
-        await self._on_tool_result_callbacks.invoke_async(result)
+    def _handle_tool_error_result(self, request: ContentToolRequest, error: Exception):
+        warnings.warn(
+            f"Calling tool '{request.name}' led to an error: {error}",
+            ToolFailureWarning,
+            stacklevel=2,
+        )
+        traceback.print_exc()
+        log_tool_error(request.name, str(request.arguments), error)
+        result = ContentToolResult(value=None, error=error, request=request)
+        self._on_tool_result_callbacks.invoke(result)
         return result
 
     def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
@@ -1571,10 +2170,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             return ChatMarkdownDisplay(MockMarkdownDisplay(), self)
 
         # rich does a lot to detect a notebook environment, but it doesn't
-        # detect Quarto (at least not yet).
+        # detect Quarto or a Positron notebook.
         from rich.console import Console
 
-        is_web = Console().is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None
+        is_web = Console().is_jupyter or is_quarto() or is_positron_notebook()
 
         opts = self._echo_options
 
@@ -1622,8 +2221,23 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
     def __repr__(self):
         turns = self.get_turns(include_system_prompt=True)
-        tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens)
-        res = f"<Chat turns={len(turns)} tokens={tokens}>"
+        tokens = self.get_tokens()
+        tokens_asst = sum(u["tokens_total"] for u in tokens if u["role"] == "assistant")
+        tokens_user = sum(u["tokens_total"] for u in tokens if u["role"] == "user")
+
+        res = f"<Chat {self.provider.name}/{self.provider.model} turns={len(turns)} tokens={tokens_user}/{tokens_asst}"
+
+        # Add cost info only if we can compute it
+        cost = compute_cost(
+            self.provider.name,
+            self.provider.model,
+            tokens_user,
+            tokens_asst,
+        )
+        if cost is not None:
+            res += f" ${round(cost, ndigits=2)}"
+
+        res += ">"
         for turn in turns:
             res += "\n" + turn.__repr__(indent=2)
         return res + "\n"
@@ -1818,3 +2432,15 @@ class ToolFailureWarning(RuntimeWarning):
 
 # By default warnings are shown once; we want to always show them.
 warnings.simplefilter("always", ToolFailureWarning)
+
+
+def is_quarto():
+    return os.getenv("QUARTO_PYTHON", None) is not None
+
+
+def is_positron_notebook():
+    try:
+        mode = get_ipython().session_mode  # noqa: F821 # type: ignore
+        return mode == "notebook"
+    except Exception:
+        return False