chatlas 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of chatlas might be problematic.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +1 -4
- chatlas/_callbacks.py +56 -0
- chatlas/_chat.py +182 -84
- chatlas/_content.py +6 -0
- chatlas/_databricks.py +1 -1
- chatlas/_logging.py +29 -5
- chatlas/_snowflake.py +398 -72
- chatlas/_tools.py +59 -1
- chatlas/_version.py +2 -2
- chatlas/types/anthropic/_submit.py +7 -0
- chatlas/types/openai/_submit.py +1 -0
- {chatlas-0.7.1.dist-info → chatlas-0.8.1.dist-info}/METADATA +2 -2
- {chatlas-0.7.1.dist-info → chatlas-0.8.1.dist-info}/RECORD +16 -17
- chatlas/types/snowflake/__init__.py +0 -8
- chatlas/types/snowflake/_submit.py +0 -24
- {chatlas-0.7.1.dist-info → chatlas-0.8.1.dist-info}/WHEEL +0 -0
- {chatlas-0.7.1.dist-info → chatlas-0.8.1.dist-info}/licenses/LICENSE +0 -0
chatlas/_snowflake.py
CHANGED
@@ -1,32 +1,60 @@
-import
-
-
-
+from typing import (
+    TYPE_CHECKING,
+    Generator,
+    Literal,
+    Optional,
+    TypedDict,
+    Union,
+    overload,
+)
+
+import orjson
 from pydantic import BaseModel

 from ._chat import Chat
-from ._content import
+from ._content import (
+    Content,
+    ContentJson,
+    ContentText,
+    ContentToolRequest,
+    ContentToolResult,
+)
 from ._logging import log_model_default
 from ._provider import Provider
+from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
 from ._turn import Turn, normalize_turns
-from ._utils import drop_none
+from ._utils import drop_none

 if TYPE_CHECKING:
-
+    import snowflake.core.cortex.inference_service._generated.models as models
+    from snowflake.core.rest import Event, SSEClient
+
+    Completion = models.NonStreamingCompleteResponse
+    CompletionChunk = models.StreamingCompleteResponseDataEvent
+
+    # Manually constructed TypedDict equivalent of models.CompleteRequest
+    class CompleteRequest(TypedDict, total=False):
+        """
+        CompleteRequest parameters for Snowflake Cortex LLMs.
+
+        See `snowflake.core.cortex.inference_service.CompleteRequest` for more details.
+        """
+
+        temperature: Union[float, int]
+        """Temperature controls the amount of randomness used in response generation. A higher temperature corresponds to more randomness."""

-
-
-CompletionChunk = str
+        top_p: Union[float, int]
+        """Threshold probability for nucleus sampling. A higher top-p value increases the diversity of tokens that the model considers, while a lower value results in more predictable output."""

-
+        max_tokens: int
+        """The maximum number of output tokens to produce. The default value is model-dependent."""

+        guardrails: models.GuardrailsConfig
+        """Controls whether guardrails are enabled."""

-
-
-class ConversationMessage(TypedDict):
-    role: str
-    content: str
+        tool_choice: models.ToolChoice
+        """Determines how tools are selected."""


 def ChatSnowflake(
@@ -41,7 +69,7 @@ def ChatSnowflake(
     private_key_file: Optional[str] = None,
     private_key_file_pwd: Optional[str] = None,
     kwargs: Optional[dict[str, "str | int"]] = None,
-) -> Chat["
+) -> Chat["CompleteRequest", "Completion"]:
     """
     Chat with a Snowflake Cortex LLM

@@ -116,7 +144,7 @@ def ChatSnowflake(
     """

     if model is None:
-        model = log_model_default("
+        model = log_model_default("claude-3-7-sonnet")

     return Chat(
         provider=SnowflakeProvider(
@@ -150,6 +178,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         kwargs: Optional[dict[str, "str | int"]],
     ):
         try:
+            from snowflake.core import Root
             from snowflake.snowpark import Session
         except ImportError:
             raise ImportError(
@@ -170,7 +199,9 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
             )

         self._model = model
-
+
+        session = Session.builder.configs(configs).create()
+        self._cortex_service = Root(session).cortex_inference_service

     @overload
     def chat_perform(
@@ -180,7 +211,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...

     @overload
@@ -191,7 +222,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...

     def chat_perform(
@@ -201,12 +232,25 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-
+        req = self._complete_request(stream, turns, tools, data_model, kwargs)
+        client = self._cortex_service.complete(req)
+
+        try:
+            events = client.events()
+        except Exception as e:
+            data = parse_request_object(client)
+            if data is None:
+                raise e
+            return data

-
-
+        if stream:
+            return generate_event_data(events)
+
+        for evt in events:
+            if evt.data:
+                return parse_event_data(evt.data, stream=False)

     @overload
     async def chat_perform_async(
@@ -216,7 +260,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...

     @overload
@@ -227,7 +271,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...

     async def chat_perform_async(
@@ -237,65 +281,164 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-
-
-
+        req = self._complete_request(stream, turns, tools, data_model, kwargs)
+        res = self._cortex_service.complete_async(req)
+        # TODO: is there a way to get the SSEClient result without blocking?
+        client = res.result()

-
-
+        try:
+            events = client.events()
+        except Exception as e:
+            data = parse_request_object(client)
+            if data is None:
+                raise e
+            return data

-        # When streaming, res is an iterable of strings, but Chat() wants an async iterable
         if stream:
-
+            return generate_event_data_async(events)

-
+        for evt in events:
+            if evt.data:
+                return parse_event_data(evt.data, stream=False)

-    def
+    def _complete_request(
         self,
         stream: bool,
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-
-
-
-
-
-
-
-
-        # TODO: get tools working
+        from snowflake.core.cortex.inference_service import CompleteRequest
+
+        req = CompleteRequest(
+            model=self._model,
+            messages=self._as_request_messages(turns),
+            stream=stream,
+        )
+
         if tools:
-
+            req.tools = req.tools or []
+            snow_tools = [self._as_snowflake_tool(tool) for tool in tools.values()]
+            req.tools.extend(snow_tools)

         if data_model is not None:
+            import snowflake.core.cortex.inference_service._generated.models as models
+
             params = basemodel_to_param_schema(data_model)
-
-
-
-                "schema": {
+            req.response_format = models.CompleteRequestResponseFormat(
+                type="json",
+                schema={
                     "type": "object",
                     "properties": params["properties"],
                     "required": params["required"],
                 },
-
-            kwargs_full["options"] = opts
+            )

-
+        if kwargs:
+            for k, v in kwargs.items():
+                if hasattr(req, k):
+                    setattr(req, k, v)
+                else:
+                    raise ValueError(
+                        f"Unknown parameter {k} for Snowflake CompleteRequest. "
+                        "Please check the Snowflake documentation for valid parameters."
+                    )

-
-        return chunk
+        return req

+    def stream_text(self, chunk):
+        if not chunk.choices:
+            return None
+        delta = chunk.choices[0].delta
+        if delta is None or "content" not in delta:
+            return None
+        return delta["content"]
+
+    # Snowflake sort-of follows OpenAI/Anthropic streaming formats except they
+    # don't have the critical "index" field in the delta that the merge logic
+    # depends on (i.e., OpenAI), or official start/stop events (i.e.,
+    # Anthropic). So we have to do some janky merging here.
+    #
+    # This was done in a panic to get working asap, so don't judge :) I wouldn't
+    # be surprised if Snowflake realizes how bad this streaming format is and
+    # changes it in the future (thus probably breaking this code :( ).
     def stream_merge_chunks(self, completion, chunk):
         if completion is None:
             return chunk
-
+
+        if completion.choices is None or chunk.choices is None:
+            raise ValueError(
+                "Unexpected None for completion.choices. Please report this issue."
+            )
+
+        if completion.choices[0].delta is None or chunk.choices[0].delta is None:
+            raise ValueError(
+                "Unexpected None for completion.choices[0].delta. Please report this issue."
+            )
+
+        delta = completion.choices[0].delta
+        new_delta = chunk.choices[0].delta
+        if "content_list" not in delta or "content_list" not in new_delta:
+            raise ValueError(
+                "Expected content_list to be in completion.choices[0].delta. Please report this issue."
+            )
+
+        content_list = delta["content_list"]
+        new_content_list = new_delta["content_list"]
+        if not isinstance(content_list, list) or not isinstance(new_content_list, list):
+            raise ValueError(
+                f"Expected content_list to be a list, got {type(new_content_list)}"
+            )
+
+        if new_delta["type"] == "tool_use":
+            # Presence of "tool_use_id" indicates a new tool request; otherwise, we're
+            # expecting input parameters
+            if "tool_use_id" in new_delta:
+                del new_delta["text"]  # why is this here :eye-roll:?
+                content_list.append(new_delta)
+            elif "input" in new_delta:
+                # find most recent content with type: "tool_use" and append to that
+                for i in range(len(content_list) - 1, -1, -1):
+                    if "tool_use_id" in content_list[i]:
+                        content_list[i]["input"] = content_list[i].get("input", "")
+                        content_list[i]["input"] += new_delta["input"]
+                        break
+            else:
+                raise ValueError(
+                    f"Unexpected tool_use delta: {new_delta}. Please report this issue."
+                )
+        elif new_delta["type"] == "text":
+            text = new_delta["text"]
+            # find most recent content with type: "text" and append to that
+            for i in range(len(content_list) - 1, -1, -1):
+                if content_list[i].get("type") == "text":
+                    content_list[i]["text"] += text
+                    break
+            else:
+                # if we don't find it, just append to the end
+                # this shouldn't happen, but just in case
+                content_list.append({"type": "text", "text": text})
+        else:
+            raise ValueError(
+                f"Unexpected streaming delta type: {new_delta['type']}. Please report this issue."
+            )
+
+        completion.choices[0].delta["content_list"] = content_list
+
+        return completion

     def stream_turn(self, completion, has_data_model) -> Turn:
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        completion_dict = completion.model_dump()
+        delta = completion_dict["choices"][0].pop("delta")
+        completion_dict["choices"][0]["message"] = delta
+        completion = models.NonStreamingCompleteResponse.model_construct(
+            **completion_dict
+        )
         return self._as_turn(completion, has_data_model)

     def value_turn(self, completion, has_data_model) -> Turn:
@@ -321,24 +464,207 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
             "Snowflake does not currently support token counting."
         )

-    def
-
+    def _as_request_messages(self, turns: list[Turn]):
+        from snowflake.core.cortex.inference_service import CompleteRequestMessagesInner
+
+        res: list[CompleteRequestMessagesInner] = []
         for turn in turns:
-
-
-
-                "content": str(turn),
-            }
+            req = CompleteRequestMessagesInner(
+                role=turn.role,
+                content=turn.text,
             )
+            for x in turn.contents:
+                if isinstance(x, ContentToolRequest):
+                    req.content_list = req.content_list or []
+                    req.content_list.append(
+                        {
+                            "type": "tool_use",
+                            "tool_use": {
+                                "tool_use_id": x.id,
+                                "name": x.name,
+                                "input": x.arguments,
+                            },
+                        }
+                    )
+                elif isinstance(x, ContentToolResult):
+                    # Snowflake doesn't like empty content
+                    req.content = req.content or "[tool_result]"
+                    req.content_list = req.content_list or []
+                    req.content_list.append(
+                        {
+                            "type": "tool_results",
+                            "tool_results": {
+                                "tool_use_id": x.id,
+                                "name": x.name,
+                                "content": [
+                                    {"type": "text", "text": x.get_model_value()}
+                                ],
+                            },
+                        }
+                    )
+                elif isinstance(x, ContentJson):
+                    req.content = req.content or "<structured data/>"
+
+            res.append(req)
         return res

-    def _as_turn(self, completion, has_data_model) -> Turn:
-
+    def _as_turn(self, completion: "Completion", has_data_model: bool) -> Turn:
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        if not completion.choices:
+            return Turn("assistant", [])
+
+        choice = completion.choices[0]
+        if isinstance(choice, dict):
+            choice = models.NonStreamingCompleteResponseChoicesInner.from_dict(choice)
+
+        message = choice.message
+        if message is None:
+            return Turn("assistant", [])
+
+        contents: list[Content] = []
+        content_list = message.content_list or []
+        for content in content_list:
+            if "text" in content:
+                if has_data_model:
+                    data = orjson.loads(content["text"])
+                    contents.append(ContentJson(value=data))
+                else:
+                    contents.append(ContentText(text=content["text"]))
+            elif "tool_use_id" in content:
+                params = content.get("input", "{}")
+                try:
+                    params = orjson.loads(params)
+                except orjson.JSONDecodeError:
+                    raise ValueError(
+                        f"Failed to parse tool_use input: {params}. Please report this issue."
+                    )
+                contents.append(
+                    ContentToolRequest(
+                        name=content["name"],
+                        id=content["tool_use_id"],
+                        arguments=params,
+                    )
+                )
+
+        usage = completion.usage
+        if usage is None:
+            tokens = (0, 0)
+        else:
+            tokens = (usage.prompt_tokens or 0, usage.completion_tokens or 0)
+
+        tokens_log(self, tokens)
+
+        return Turn(
+            "assistant",
+            contents,
+            tokens=tokens,
+            # TODO: no finish_reason in Snowflake?
+            # finish_reason=completion.choices[0].finish_reason,
+            completion=completion,
+        )
+
+    # N.B. this is currently the best documentation I can find for how tool calling works
+    # https://quickstarts.snowflake.com/guide/getting-started-with-tool-use-on-cortex-and-anthropic-claude/index.html#5
+    def _as_snowflake_tool(self, tool: Tool):
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        func = tool.schema["function"]
+        params = func.get("parameters", {})
+
+        props = params.get("properties", {})
+        if not isinstance(props, dict):
+            raise ValueError(
+                f"Tool function parameters must be a dictionary, got {type(props)}"
+            )
+
+        required = params.get("required", [])
+        if not isinstance(required, list):
+            raise ValueError(
+                f"Tool function required parameters must be a list, got {type(required)}"
+            )
+
+        input_schema = models.ToolToolSpecInputSchema(
+            type="object",
+            properties=props or None,
+            required=required or None,
+        )
+
+        spec = models.ToolToolSpec(
+            type="generic",
+            name=func["name"],
+            description=func.get("description", ""),
+            input_schema=input_schema,
+        )

-
-
-
+        return models.Tool(tool_spec=spec)
+
+
+# Yield parsed event data from the Snowflake SSEClient
+# (this is only needed for the streaming case).
+def generate_event_data(events: Generator["Event", None, None]):
+    for x in events:
+        if x.data:
+            yield parse_event_data(x.data, stream=True)
+
+
+# Same thing for the async case.
+async def generate_event_data_async(events: Generator["Event", None, None]):
+    for x in events:
+        if x.data:
+            yield parse_event_data(x.data, stream=True)
+
+
+@overload
+def parse_event_data(
+    data: str, stream: Literal[True]
+) -> "models.StreamingCompleteResponseDataEvent": ...
+
+
+@overload
+def parse_event_data(
+    data: str, stream: Literal[False]
+) -> "models.NonStreamingCompleteResponse": ...
+
+
+def parse_event_data(
+    data: str, stream: bool
+) -> "models.NonStreamingCompleteResponse | models.StreamingCompleteResponseDataEvent":
+    "Parse the (JSON) event data from Snowflake using the relevant pydantic model."
+    import snowflake.core.cortex.inference_service._generated.models as models
+
+    try:
+        if stream:
+            return models.StreamingCompleteResponseDataEvent.from_json(data)
         else:
-
+            return models.NonStreamingCompleteResponse.from_json(data)
+    except Exception:
+        raise ValueError(
+            f"Failed to parse Snowflake event data: {data}. "
+            "Please report this error here: https://github.com/posit-dev/chatlas/issues/new"
+        )
+
+
+# At the time of writing, .events() flat out errors in the stream=False case since
+# the Content-Type is set to application/json;charset=utf-8, and SSEClient
+# doesn't know how to handle that.
+# https://github.com/snowflakedb/snowflake-ml-python/blob/6910e96/snowflake/cortex/_sse_client.py#L69
+#
+# So, do some janky stuff here to get the data out of the response.
+#
+# If and when snowflake fixes this, we can remove the try/except block.
+def parse_request_object(
+    client: "SSEClient",
+) -> "Optional[models.NonStreamingCompleteResponse]":
+    try:
+        import urllib3
+
+        if isinstance(client._event_source, urllib3.response.HTTPResponse):
+            return parse_event_data(
+                client._event_source.data.decode("utf-8"),
+                stream=False,
+            )
+    except Exception:
+        pass

-
+    return None
chatlas/_tools.py
CHANGED
@@ -8,7 +8,10 @@ from pydantic import BaseModel, Field, create_model

 from . import _utils

-__all__ = (
+__all__ = (
+    "Tool",
+    "ToolRejectError",
+)

 if TYPE_CHECKING:
     from openai.types.chat import ChatCompletionToolParam
@@ -47,6 +50,61 @@ class Tool:
         self.name = self.schema["function"]["name"]


+class ToolRejectError(Exception):
+    """
+    Error to represent a tool call being rejected.
+
+    This error is meant to be raised when an end user has chosen to deny a tool
+    call. It can be raised in a tool function or in a `.on_tool_request()`
+    callback registered via a :class:`~chatlas.Chat`. When used in the callback,
+    the tool call is rejected before the tool function is invoked.
+
+    Parameters
+    ----------
+    reason
+        A string describing the reason for rejecting the tool call. This will be
+        included in the error message passed to the LLM. In addition to the
+        reason, the error message will also include "Tool call rejected." to
+        indicate that the tool call was not processed.
+
+    Raises
+    -------
+    ToolRejectError
+        An error with a message informing the LLM that the tool call was
+        rejected (and the reason why).
+
+    Examples
+    --------
+    >>> import os
+    >>> import chatlas as ctl
+    >>>
+    >>> chat = ctl.ChatOpenAI()
+    >>>
+    >>> def list_files():
+    ...     "List files in the user's current directory"
+    ...     while True:
+    ...         allow = input(
+    ...             "Would you like to allow access to your current directory? (yes/no): "
+    ...         )
+    ...         if allow.lower() == "yes":
+    ...             return os.listdir(".")
+    ...         elif allow.lower() == "no":
+    ...             raise ctl.ToolRejectError(
+    ...                 "The user has chosen to disallow the tool call."
+    ...             )
+    ...         else:
+    ...             print("Please answer with 'yes' or 'no'.")
+    >>>
+    >>> chat.register_tool(list_files)
+    >>> chat.chat("What files are available in my current directory?")
+    """
+
+    def __init__(self, reason: str = "The user has chosen to disallow the tool call."):
+        message = f"Tool call rejected. {reason}"
+        super().__init__(message)
+        self.message = message
+
+
 def func_to_schema(
     func: Callable[..., Any] | Callable[..., Awaitable[Any]],
     model: Optional[type[BaseModel]] = None,
chatlas/_version.py
CHANGED