chatlas 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

chatlas/_snowflake.py CHANGED
@@ -1,32 +1,60 @@
-import asyncio
-import json
-from typing import TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, cast, overload
-
+from typing import (
+    TYPE_CHECKING,
+    Generator,
+    Literal,
+    Optional,
+    TypedDict,
+    Union,
+    overload,
+)
+
+import orjson
 from pydantic import BaseModel
 
 from ._chat import Chat
-from ._content import Content, ContentJson, ContentText
+from ._content import (
+    Content,
+    ContentJson,
+    ContentText,
+    ContentToolRequest,
+    ContentToolResult,
+)
 from ._logging import log_model_default
 from ._provider import Provider
+from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
 from ._turn import Turn, normalize_turns
-from ._utils import drop_none, wrap_async_iterable
+from ._utils import drop_none
 
 if TYPE_CHECKING:
-    from snowflake.snowpark import Column
+    import snowflake.core.cortex.inference_service._generated.models as models
+    from snowflake.core.rest import Event, SSEClient
+
+    Completion = models.NonStreamingCompleteResponse
+    CompletionChunk = models.StreamingCompleteResponseDataEvent
+
+    # Manually constructed TypedDict equivalent of models.CompleteRequest
+    class CompleteRequest(TypedDict, total=False):
+        """
+        CompleteRequest parameters for Snowflake Cortex LLMs.
+
+        See `snowflake.core.cortex.inference_service.CompleteRequest` for more details.
+        """
+
+        temperature: Union[float, int]
+        """Temperature controls the amount of randomness used in response generation. A higher temperature corresponds to more randomness."""
 
-    # Types inferred from the return type of the `snowflake.cortex.complete` function
-    Completion = str | Column
-    CompletionChunk = str
+        top_p: Union[float, int]
+        """Threshold probability for nucleus sampling. A higher top-p value increases the diversity of tokens that the model considers, while a lower value results in more predictable output."""
 
-    from .types.snowflake import SubmitInputArgs
+        max_tokens: int
+        """The maximum number of output tokens to produce. The default value is model-dependent."""
 
+        guardrails: models.GuardrailsConfig
+        """Controls whether guardrails are enabled."""
 
-# The main prompt input type for Snowflake
-# This was copy-pasted from `snowflake.cortex._complete.ConversationMessage`
-class ConversationMessage(TypedDict):
-    role: str
-    content: str
+        tool_choice: models.ToolChoice
+        """Determines how tools are selected."""
 
 
 def ChatSnowflake(
@@ -41,7 +69,7 @@ def ChatSnowflake(
     private_key_file: Optional[str] = None,
     private_key_file_pwd: Optional[str] = None,
     kwargs: Optional[dict[str, "str | int"]] = None,
-) -> Chat["SubmitInputArgs", "Completion"]:
+) -> Chat["CompleteRequest", "Completion"]:
     """
     Chat with a Snowflake Cortex LLM
 
@@ -116,7 +144,7 @@ def ChatSnowflake(
     """
 
     if model is None:
-        model = log_model_default("llama3.1-70b")
+        model = log_model_default("claude-3-7-sonnet")
 
     return Chat(
         provider=SnowflakeProvider(
@@ -150,6 +178,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         kwargs: Optional[dict[str, "str | int"]],
     ):
         try:
+            from snowflake.core import Root
             from snowflake.snowpark import Session
         except ImportError:
             raise ImportError(
@@ -170,7 +199,9 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
             )
 
         self._model = model
-        self._session = Session.builder.configs(configs).create()
+
+        session = Session.builder.configs(configs).create()
+        self._cortex_service = Root(session).cortex_inference_service
 
     @overload
     def chat_perform(
@@ -180,7 +211,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     @overload
@@ -191,7 +222,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     def chat_perform(
@@ -201,12 +232,25 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-        from snowflake.cortex import complete
+        req = self._complete_request(stream, turns, tools, data_model, kwargs)
+        client = self._cortex_service.complete(req)
+
+        try:
+            events = client.events()
+        except Exception as e:
+            data = parse_request_object(client)
+            if data is None:
+                raise e
+            return data
 
-        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
-        return complete(**kwargs)
+        if stream:
+            return generate_event_data(events)
+
+        for evt in events:
+            if evt.data:
+                return parse_event_data(evt.data, stream=False)
 
     @overload
     async def chat_perform_async(
@@ -216,7 +260,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     @overload
@@ -227,7 +271,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     async def chat_perform_async(
@@ -237,65 +281,164 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-        from snowflake.cortex import complete
-
-        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
+        req = self._complete_request(stream, turns, tools, data_model, kwargs)
+        res = self._cortex_service.complete_async(req)
+        # TODO: is there a way to get the SSEClient result without blocking?
+        client = res.result()
 
-        # Prevent the main thread from being blocked (Snowflake doesn't have native async support)
-        res = await asyncio.to_thread(complete, **kwargs)
+        try:
+            events = client.events()
+        except Exception as e:
+            data = parse_request_object(client)
+            if data is None:
+                raise e
+            return data
 
-        # When streaming, res is an iterable of strings, but Chat() wants an async iterable
         if stream:
-            res = wrap_async_iterable(cast(Iterable[str], res))
+            return generate_event_data_async(events)
 
-        return res
+        for evt in events:
+            if evt.data:
+                return parse_event_data(evt.data, stream=False)
 
-    def _chat_perform_args(
+    def _complete_request(
         self,
         stream: bool,
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["SubmitInputArgs"] = None,
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-        kwargs_full: "SubmitInputArgs" = {
-            "stream": stream,
-            "prompt": self._as_prompt_input(turns),
-            "model": self._model,
-            "session": self._session,
-            **(kwargs or {}),
-        }
-
-        # TODO: get tools working
+        from snowflake.core.cortex.inference_service import CompleteRequest
+
+        req = CompleteRequest(
+            model=self._model,
+            messages=self._as_request_messages(turns),
+            stream=stream,
+        )
+
         if tools:
-            raise ValueError("Snowflake does not currently support tools.")
+            req.tools = req.tools or []
+            snow_tools = [self._as_snowflake_tool(tool) for tool in tools.values()]
+            req.tools.extend(snow_tools)
 
         if data_model is not None:
+            import snowflake.core.cortex.inference_service._generated.models as models
+
             params = basemodel_to_param_schema(data_model)
-            opts = kwargs_full.get("options") or {}
-            opts["response_format"] = {
-                "type": "json",
-                "schema": {
+            req.response_format = models.CompleteRequestResponseFormat(
+                type="json",
+                schema={
                     "type": "object",
                     "properties": params["properties"],
                     "required": params["required"],
                 },
-            }
-            kwargs_full["options"] = opts
+            )
 
-        return kwargs_full
+        if kwargs:
+            for k, v in kwargs.items():
+                if hasattr(req, k):
+                    setattr(req, k, v)
+                else:
+                    raise ValueError(
+                        f"Unknown parameter {k} for Snowflake CompleteRequest. "
+                        "Please check the Snowflake documentation for valid parameters."
+                    )
 
-    def stream_text(self, chunk):
-        return chunk
+        return req
 
+    def stream_text(self, chunk):
+        if not chunk.choices:
+            return None
+        delta = chunk.choices[0].delta
+        if delta is None or "content" not in delta:
+            return None
+        return delta["content"]
+
+    # Snowflake sort-of follows OpenAI/Anthropic streaming formats except they
+    # don't have the critical "index" field in the delta that the merge logic
+    # depends on (i.e., OpenAI), or official start/stop events (i.e.,
+    # Anthropic). So we have to do some janky merging here.
+    #
+    # This was done in a panic to get working asap, so don't judge :) I wouldn't
+    # be surprised if Snowflake realizes how bad this streaming format is and
+    # changes it in the future (thus probably breaking this code :( ).
     def stream_merge_chunks(self, completion, chunk):
         if completion is None:
             return chunk
-        return completion + chunk
+
+        if completion.choices is None or chunk.choices is None:
+            raise ValueError(
+                "Unexpected None for completion.choices. Please report this issue."
+            )
+
+        if completion.choices[0].delta is None or chunk.choices[0].delta is None:
+            raise ValueError(
+                "Unexpected None for completion.choices[0].delta. Please report this issue."
+            )
+
+        delta = completion.choices[0].delta
+        new_delta = chunk.choices[0].delta
+        if "content_list" not in delta or "content_list" not in new_delta:
+            raise ValueError(
+                "Expected content_list to be in completion.choices[0].delta. Please report this issue."
+            )
+
+        content_list = delta["content_list"]
+        new_content_list = new_delta["content_list"]
+        if not isinstance(content_list, list) or not isinstance(new_content_list, list):
+            raise ValueError(
+                f"Expected content_list to be a list, got {type(new_content_list)}"
+            )
+
+        if new_delta["type"] == "tool_use":
+            # Presence of "tool_use_id" indicates a new tool request; otherwise, we're
+            # expecting input parameters
+            if "tool_use_id" in new_delta:
+                del new_delta["text"]  # why is this here :eye-roll:?
+                content_list.append(new_delta)
+            elif "input" in new_delta:
+                # find most recent content with type: "tool_use" and append to that
+                for i in range(len(content_list) - 1, -1, -1):
+                    if "tool_use_id" in content_list[i]:
+                        content_list[i]["input"] = content_list[i].get("input", "")
+                        content_list[i]["input"] += new_delta["input"]
+                        break
+            else:
+                raise ValueError(
+                    f"Unexpected tool_use delta: {new_delta}. Please report this issue."
+                )
+        elif new_delta["type"] == "text":
+            text = new_delta["text"]
+            # find most recent content with type: "text" and append to that
+            for i in range(len(content_list) - 1, -1, -1):
+                if content_list[i].get("type") == "text":
+                    content_list[i]["text"] += text
+                    break
+            else:
+                # if we don't find it, just append to the end
+                # this shouldn't happen, but just in case
+                content_list.append({"type": "text", "text": text})
+        else:
+            raise ValueError(
+                f"Unexpected streaming delta type: {new_delta['type']}. Please report this issue."
+            )
+
+        completion.choices[0].delta["content_list"] = content_list
+
+        return completion
 
     def stream_turn(self, completion, has_data_model) -> Turn:
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        completion_dict = completion.model_dump()
+        delta = completion_dict["choices"][0].pop("delta")
+        completion_dict["choices"][0]["message"] = delta
+        completion = models.NonStreamingCompleteResponse.model_construct(
+            **completion_dict
+        )
         return self._as_turn(completion, has_data_model)
 
     def value_turn(self, completion, has_data_model) -> Turn:
@@ -321,24 +464,207 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
                 "Snowflake does not currently support token counting."
             )
 
-    def _as_prompt_input(self, turns: list[Turn]) -> list["ConversationMessage"]:
-        res: list["ConversationMessage"] = []
+    def _as_request_messages(self, turns: list[Turn]):
+        from snowflake.core.cortex.inference_service import CompleteRequestMessagesInner
+
+        res: list[CompleteRequestMessagesInner] = []
         for turn in turns:
-            res.append(
-                {
-                    "role": turn.role,
-                    "content": str(turn),
-                }
+            req = CompleteRequestMessagesInner(
+                role=turn.role,
+                content=turn.text,
             )
+            for x in turn.contents:
+                if isinstance(x, ContentToolRequest):
+                    req.content_list = req.content_list or []
+                    req.content_list.append(
+                        {
+                            "type": "tool_use",
+                            "tool_use": {
+                                "tool_use_id": x.id,
+                                "name": x.name,
+                                "input": x.arguments,
+                            },
+                        }
+                    )
+                elif isinstance(x, ContentToolResult):
+                    # Snowflake does not like empty content
+                    req.content = req.content or "[tool_result]"
+                    req.content_list = req.content_list or []
+                    req.content_list.append(
+                        {
+                            "type": "tool_results",
+                            "tool_results": {
+                                "tool_use_id": x.id,
+                                "name": x.name,
+                                "content": [
+                                    {"type": "text", "text": x.get_model_value()}
+                                ],
+                            },
+                        }
+                    )
+                elif isinstance(x, ContentJson):
+                    req.content = req.content or "<structured data/>"
+
+            res.append(req)
         return res
 
-    def _as_turn(self, completion, has_data_model) -> Turn:
-        completion = cast(str, completion)
+    def _as_turn(self, completion: "Completion", has_data_model: bool) -> Turn:
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        if not completion.choices:
+            return Turn("assistant", [])
+
+        choice = completion.choices[0]
+        if isinstance(choice, dict):
+            choice = models.NonStreamingCompleteResponseChoicesInner.from_dict(choice)
+
+        message = choice.message
+        if message is None:
+            return Turn("assistant", [])
+
+        contents: list[Content] = []
+        content_list = message.content_list or []
+        for content in content_list:
+            if "text" in content:
+                if has_data_model:
+                    data = orjson.loads(content["text"])
+                    contents.append(ContentJson(value=data))
+                else:
+                    contents.append(ContentText(text=content["text"]))
+            elif "tool_use_id" in content:
+                params = content.get("input", "{}")
+                try:
+                    params = orjson.loads(params)
+                except orjson.JSONDecodeError:
+                    raise ValueError(
+                        f"Failed to parse tool_use input: {params}. Please report this issue."
+                    )
+                contents.append(
+                    ContentToolRequest(
+                        name=content["name"],
+                        id=content["tool_use_id"],
+                        arguments=params,
+                    )
+                )
+
+        usage = completion.usage
+        if usage is None:
+            tokens = (0, 0)
+        else:
+            tokens = (usage.prompt_tokens or 0, usage.completion_tokens or 0)
+
+        tokens_log(self, tokens)
+
+        return Turn(
+            "assistant",
+            contents,
+            tokens=tokens,
+            # TODO: no finish_reason in Snowflake?
+            # finish_reason=completion.choices[0].finish_reason,
+            completion=completion,
+        )
+
+    # N.B. this is currently the best documentation I can find for how tool calling works
+    # https://quickstarts.snowflake.com/guide/getting-started-with-tool-use-on-cortex-and-anthropic-claude/index.html#5
+    def _as_snowflake_tool(self, tool: Tool):
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        func = tool.schema["function"]
+        params = func.get("parameters", {})
+
+        props = params.get("properties", {})
+        if not isinstance(props, dict):
+            raise ValueError(
+                f"Tool function parameters must be a dictionary, got {type(props)}"
+            )
+
+        required = params.get("required", [])
+        if not isinstance(required, list):
+            raise ValueError(
+                f"Tool function required parameters must be a list, got {type(required)}"
+            )
+
+        input_schema = models.ToolToolSpecInputSchema(
+            type="object",
+            properties=props or None,
+            required=required or None,
+        )
+
+        spec = models.ToolToolSpec(
+            type="generic",
+            name=func["name"],
+            description=func.get("description", ""),
+            input_schema=input_schema,
+        )
 
-        if has_data_model:
-            data = json.loads(completion)
-            contents = [ContentJson(value=data)]
+        return models.Tool(tool_spec=spec)
+
+
+# Yield parsed event data from the Snowflake SSEClient
+# (this is only needed for the streaming case).
+def generate_event_data(events: Generator["Event", None, None]):
+    for x in events:
+        if x.data:
+            yield parse_event_data(x.data, stream=True)
+
+
+# Same thing for the async case.
+async def generate_event_data_async(events: Generator["Event", None, None]):
+    for x in events:
+        if x.data:
+            yield parse_event_data(x.data, stream=True)
+
+
+@overload
+def parse_event_data(
+    data: str, stream: Literal[True]
+) -> "models.StreamingCompleteResponseDataEvent": ...
+
+
+@overload
+def parse_event_data(
+    data: str, stream: Literal[False]
+) -> "models.NonStreamingCompleteResponse": ...
+
+
+def parse_event_data(
+    data: str, stream: bool
+) -> "models.NonStreamingCompleteResponse | models.StreamingCompleteResponseDataEvent":
+    "Parse the (JSON) event data from Snowflake using the relevant pydantic model."
+    import snowflake.core.cortex.inference_service._generated.models as models
+
+    try:
+        if stream:
+            return models.StreamingCompleteResponseDataEvent.from_json(data)
         else:
-            contents = [ContentText(text=completion)]
+            return models.NonStreamingCompleteResponse.from_json(data)
+    except Exception:
+        raise ValueError(
+            f"Failed to parse Snowflake event data: {data}. "
+            "Please report this error here: https://github.com/posit-dev/chatlas/issues/new"
+        )
+
+
+# At the time of writing, .events() flat out errors in the stream=False case since
+# the Content-Type is set to application/json;charset=utf-8, and SSEClient
+# doesn't know how to handle that.
+# https://github.com/snowflakedb/snowflake-ml-python/blob/6910e96/snowflake/cortex/_sse_client.py#L69
+#
+# So, do some janky stuff here to get the data out of the response.
+#
+# If and when Snowflake fixes this, we can remove the try/except block.
+def parse_request_object(
+    client: "SSEClient",
+) -> "Optional[models.NonStreamingCompleteResponse]":
+    try:
+        import urllib3
+
+        if isinstance(client._event_source, urllib3.response.HTTPResponse):
+            return parse_event_data(
+                client._event_source.data.decode("utf-8"),
+                stream=False,
+            )
+    except Exception:
+        pass
 
-    return Turn("assistant", contents)
+    return None
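
Taken together, the `_snowflake.py` changes replace the old `snowflake.cortex.complete` path with the Cortex REST inference service, add tool calling (previously a hard `ValueError`), and type per-request options with the `CompleteRequest` TypedDict. The following is a minimal usage sketch, not part of the diff: it assumes Snowflake connection details are already configured for `ChatSnowflake`, and that `.chat()` forwards a `kwargs` argument, as the `Chat["CompleteRequest", "Completion"]` return type suggests.

import chatlas as ctl

# The default model is now "claude-3-7-sonnet" (was "llama3.1-70b").
chat = ctl.ChatSnowflake(model="claude-3-7-sonnet")

# Tools are now supported; registration routes through _as_snowflake_tool().
def get_weather(city: str) -> str:
    "Get the current weather for a city."
    return f"Sunny in {city}"  # hypothetical tool body

chat.register_tool(get_weather)

# Per-request options come from the CompleteRequest TypedDict above; note
# that unknown keys now raise a ValueError instead of being forwarded blindly.
chat.chat(
    "What's the weather in Oslo?",
    kwargs={"temperature": 0.2, "max_tokens": 512},
)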
chatlas/_tools.py CHANGED
@@ -8,7 +8,10 @@ from pydantic import BaseModel, Field, create_model
 
 from . import _utils
 
-__all__ = ("Tool",)
+__all__ = (
+    "Tool",
+    "ToolRejectError",
+)
 
 if TYPE_CHECKING:
     from openai.types.chat import ChatCompletionToolParam
@@ -47,6 +50,61 @@ class Tool:
         self.name = self.schema["function"]["name"]
 
 
+class ToolRejectError(Exception):
+    """
+    Error to represent a tool call being rejected.
+
+    This error is meant to be raised when an end user has chosen to deny a tool
+    call. It can be raised in a tool function or in a `.on_tool_request()`
+    callback registered via a :class:`~chatlas.Chat`. When used in the callback,
+    the tool call is rejected before the tool function is invoked.
+
+    Parameters
+    ----------
+    reason
+        A string describing the reason for rejecting the tool call. This will be
+        included in the error message passed to the LLM. In addition to the
+        reason, the error message will also include "Tool call rejected." to
+        indicate that the tool call was not processed.
+
+    Raises
+    ------
+    ToolRejectError
+        An error with a message informing the LLM that the tool call was
+        rejected (and the reason why).
+
+    Examples
+    --------
+    >>> import os
+    >>> import chatlas as ctl
+    >>>
+    >>> chat = ctl.ChatOpenAI()
+    >>>
+    >>> def list_files():
+    ...     "List files in the user's current directory"
+    ...     while True:
+    ...         allow = input(
+    ...             "Would you like to allow access to your current directory? (yes/no): "
+    ...         )
+    ...         if allow.lower() == "yes":
+    ...             return os.listdir(".")
+    ...         elif allow.lower() == "no":
+    ...             raise ctl.ToolRejectError(
+    ...                 "The user has chosen to disallow the tool call."
+    ...             )
+    ...         else:
+    ...             print("Please answer with 'yes' or 'no'.")
+    >>>
+    >>> chat.register_tool(list_files)
+    >>> chat.chat("What files are available in my current directory?")
+    """
+
+    def __init__(self, reason: str = "The user has chosen to disallow the tool call."):
+        message = f"Tool call rejected. {reason}"
+        super().__init__(message)
+        self.message = message
+
+
 def func_to_schema(
     func: Callable[..., Any] | Callable[..., Awaitable[Any]],
     model: Optional[type[BaseModel]] = None,
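
The docstring example above raises `ToolRejectError` from inside the tool itself; the docstring also notes it can be raised from a `.on_tool_request()` callback, which rejects the call before the tool function ever runs. A short sketch of that pattern follows; the callback registration shape and the `request` object's attributes are assumptions here (only `ContentToolRequest.name` appears in this diff), so consult the chatlas docs for the exact API.

import chatlas as ctl

chat = ctl.ChatOpenAI()

def delete_file(path: str) -> str:
    "Delete the file at the given path."
    raise NotImplementedError  # never reached when the request is rejected

chat.register_tool(delete_file)

def guard(request):
    # The reason is relayed to the LLM as "Tool call rejected. <reason>".
    if request.name == "delete_file":
        raise ctl.ToolRejectError("Destructive tools are disabled in this session.")

chat.on_tool_request(guard)
chat.chat("Please delete notes.txt")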
chatlas/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.1'
-__version_tuple__ = version_tuple = (0, 7, 1)
+__version__ = version = '0.8.1'
+__version_tuple__ = version_tuple = (0, 8, 1)