chatlas 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chatlas might be problematic. Click here for more details.
- chatlas/__init__.py +5 -0
- chatlas/_anthropic.py +12 -6
- chatlas/_auto.py +7 -3
- chatlas/_chat.py +339 -120
- chatlas/_content.py +230 -32
- chatlas/_databricks.py +145 -0
- chatlas/_display.py +13 -7
- chatlas/_google.py +9 -5
- chatlas/_ollama.py +3 -2
- chatlas/_openai.py +9 -8
- chatlas/_snowflake.py +46 -23
- chatlas/_utils.py +36 -1
- chatlas/_version.py +2 -2
- chatlas/types/openai/_submit.py +11 -1
- {chatlas-0.6.1.dist-info → chatlas-0.7.0.dist-info}/METADATA +8 -1
- {chatlas-0.6.1.dist-info → chatlas-0.7.0.dist-info}/RECORD +17 -16
- {chatlas-0.6.1.dist-info → chatlas-0.7.0.dist-info}/WHEEL +0 -0
chatlas/_content.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import textwrap
|
|
4
4
|
from pprint import pformat
|
|
5
|
-
from typing import Any, Literal, Optional, Union
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
|
6
6
|
|
|
7
|
+
import orjson
|
|
7
8
|
from pydantic import BaseModel, ConfigDict
|
|
8
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from htmltools import TagChild
|
|
12
|
+
|
|
9
13
|
ImageContentTypes = Literal[
|
|
10
14
|
"image/png",
|
|
11
15
|
"image/jpeg",
|
|
@@ -174,7 +178,7 @@ class ContentToolRequest(Content):
|
|
|
174
178
|
def __str__(self):
|
|
175
179
|
args_str = self._arguments_str()
|
|
176
180
|
func_call = f"{self.name}({args_str})"
|
|
177
|
-
comment = f"# tool request ({self.id})"
|
|
181
|
+
comment = f"# 🔧 tool request ({self.id})"
|
|
178
182
|
return f"```python\n{comment}\n{func_call}\n```\n"
|
|
179
183
|
|
|
180
184
|
def _repr_markdown_(self):
|
|
@@ -192,49 +196,104 @@ class ContentToolRequest(Content):
|
|
|
192
196
|
return ", ".join(f"{k}={v}" for k, v in self.arguments.items())
|
|
193
197
|
return str(self.arguments)
|
|
194
198
|
|
|
199
|
+
def tagify(self) -> "TagChild":
|
|
200
|
+
"Returns an HTML string suitable for passing to htmltools/shiny's `Chat()` component."
|
|
201
|
+
try:
|
|
202
|
+
from htmltools import HTML, TagList, head_content, tags
|
|
203
|
+
except ImportError:
|
|
204
|
+
raise ImportError(
|
|
205
|
+
".tagify() is only intended to be called by htmltools/shiny, ",
|
|
206
|
+
"but htmltools is not installed. ",
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
html = f"<p></p><span class='chatlas-tool-request'>🔧 Running tool: <code>{self.name}</code></span>"
|
|
210
|
+
|
|
211
|
+
return TagList(
|
|
212
|
+
HTML(html),
|
|
213
|
+
head_content(tags.style(TOOL_CSS)),
|
|
214
|
+
)
|
|
215
|
+
|
|
195
216
|
|
|
196
217
|
class ContentToolResult(Content):
|
|
197
218
|
"""
|
|
198
219
|
The result of calling a tool/function
|
|
199
220
|
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
221
|
+
A content type representing the result of a tool function call. When a model
|
|
222
|
+
requests a tool function, [](`~chatlas.Chat`) will create, (optionally)
|
|
223
|
+
echo, (optionally) yield, and store this content type in the chat history.
|
|
224
|
+
|
|
225
|
+
A tool function may also construct an instance of this class and return it.
|
|
226
|
+
This is useful for a tool that wishes to customize how the result is handled
|
|
227
|
+
(e.g., the format of the value sent to the model).
|
|
203
228
|
|
|
204
229
|
Parameters
|
|
205
230
|
----------
|
|
206
|
-
id
|
|
207
|
-
The unique identifier of the tool request.
|
|
208
231
|
value
|
|
209
|
-
The value
|
|
210
|
-
|
|
211
|
-
The
|
|
232
|
+
The return value of the tool/function.
|
|
233
|
+
model_format
|
|
234
|
+
The format used for sending the value to the model. The default,
|
|
235
|
+
`"auto"`, first attempts to format the value as a JSON string. If that
|
|
236
|
+
fails, it gets converted to a string via `str()`. To force
|
|
237
|
+
`orjson.dumps()` or `str()`, set to `"json"` or `"str"`. Finally,
|
|
238
|
+
`"as_is"` is useful for doing your own formatting and/or passing a
|
|
239
|
+
non-string value (e.g., a list or dict) straight to the model.
|
|
240
|
+
Non-string values are useful for tools that return images or other
|
|
241
|
+
'known' non-text content types.
|
|
212
242
|
error
|
|
213
|
-
An
|
|
243
|
+
An exception that occurred while invoking the tool. If this is set, the
|
|
244
|
+
error message sent to the model and the value is ignored.
|
|
245
|
+
extra
|
|
246
|
+
Additional data associated with the tool result that isn't sent to the
|
|
247
|
+
model.
|
|
248
|
+
request
|
|
249
|
+
Not intended to be used directly. It will be set when the
|
|
250
|
+
:class:`~chatlas.Chat` invokes the tool.
|
|
251
|
+
|
|
252
|
+
Note
|
|
253
|
+
----
|
|
254
|
+
When `model_format` is `"json"` (or `"auto"`), and the value has a
|
|
255
|
+
`.to_json()`/`.to_dict()` method, those methods are called to obtain the
|
|
256
|
+
JSON representation of the value. This is convenient for classes, like
|
|
257
|
+
`pandas.DataFrame`, that have a `.to_json()` method, but don't necessarily
|
|
258
|
+
dump to JSON directly. If this happens to not be the desired behavior, set
|
|
259
|
+
`model_format="as_is"` return the desired value as-is.
|
|
214
260
|
"""
|
|
215
261
|
|
|
216
|
-
|
|
217
|
-
value: Any
|
|
218
|
-
|
|
219
|
-
error: Optional[
|
|
262
|
+
# public
|
|
263
|
+
value: Any
|
|
264
|
+
model_format: Literal["auto", "json", "str", "as_is"] = "auto"
|
|
265
|
+
error: Optional[Exception] = None
|
|
266
|
+
extra: Any = None
|
|
220
267
|
|
|
268
|
+
# "private"
|
|
269
|
+
request: Optional[ContentToolRequest] = None
|
|
221
270
|
content_type: ContentTypeEnum = "tool_result"
|
|
222
271
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
272
|
+
@property
|
|
273
|
+
def id(self):
|
|
274
|
+
if not self.request:
|
|
275
|
+
raise ValueError("id is only available after the tool has been called")
|
|
276
|
+
return self.request.id
|
|
277
|
+
|
|
278
|
+
@property
|
|
279
|
+
def name(self):
|
|
280
|
+
if not self.request:
|
|
281
|
+
raise ValueError("name is only available after the tool has been called")
|
|
282
|
+
return self.request.name
|
|
283
|
+
|
|
284
|
+
@property
|
|
285
|
+
def arguments(self):
|
|
286
|
+
if not self.request:
|
|
287
|
+
raise ValueError(
|
|
288
|
+
"arguments is only available after the tool has been called"
|
|
289
|
+
)
|
|
290
|
+
return self.request.arguments
|
|
233
291
|
|
|
234
292
|
# Primarily used for `echo="all"`...
|
|
235
293
|
def __str__(self):
|
|
236
|
-
|
|
237
|
-
|
|
294
|
+
prefix = "✅ tool result" if not self.error else "❌ tool error"
|
|
295
|
+
comment = f"# {prefix} ({self.id})"
|
|
296
|
+
value = self._get_display_value()
|
|
238
297
|
return f"""```python\n{comment}\n{value}\n```"""
|
|
239
298
|
|
|
240
299
|
# ... and for displaying in the notebook
|
|
@@ -248,9 +307,99 @@ class ContentToolResult(Content):
|
|
|
248
307
|
res += f" error='{self.error}'"
|
|
249
308
|
return res + ">"
|
|
250
309
|
|
|
251
|
-
#
|
|
252
|
-
def
|
|
253
|
-
|
|
310
|
+
# Format the value for display purposes
|
|
311
|
+
def _get_display_value(self) -> object:
|
|
312
|
+
if self.error:
|
|
313
|
+
return f"Tool call failed with error: '{self.error}'"
|
|
314
|
+
|
|
315
|
+
val = self.value
|
|
316
|
+
|
|
317
|
+
# If value is already a dict or list, format it directly
|
|
318
|
+
if isinstance(val, (dict, list)):
|
|
319
|
+
return pformat(val, indent=2, sort_dicts=False)
|
|
320
|
+
|
|
321
|
+
# For string values, try to parse as JSON
|
|
322
|
+
if isinstance(val, str):
|
|
323
|
+
try:
|
|
324
|
+
json_val = orjson.loads(val)
|
|
325
|
+
return pformat(json_val, indent=2, sort_dicts=False)
|
|
326
|
+
except orjson.JSONDecodeError:
|
|
327
|
+
# Not valid JSON, return as string
|
|
328
|
+
return val
|
|
329
|
+
|
|
330
|
+
return val
|
|
331
|
+
|
|
332
|
+
def get_model_value(self) -> object:
|
|
333
|
+
"Get the actual value sent to the model."
|
|
334
|
+
|
|
335
|
+
if self.error:
|
|
336
|
+
return f"Tool call failed with error: '{self.error}'"
|
|
337
|
+
|
|
338
|
+
val, mode = (self.value, self.model_format)
|
|
339
|
+
|
|
340
|
+
if isinstance(val, str):
|
|
341
|
+
return val
|
|
342
|
+
|
|
343
|
+
if mode == "auto":
|
|
344
|
+
try:
|
|
345
|
+
return self._to_json(val)
|
|
346
|
+
except Exception:
|
|
347
|
+
return str(val)
|
|
348
|
+
elif mode == "json":
|
|
349
|
+
return self._to_json(val)
|
|
350
|
+
elif mode == "str":
|
|
351
|
+
return str(val)
|
|
352
|
+
elif mode == "as_is":
|
|
353
|
+
return val
|
|
354
|
+
else:
|
|
355
|
+
raise ValueError(f"Unknown format mode: {mode}")
|
|
356
|
+
|
|
357
|
+
@staticmethod
|
|
358
|
+
def _to_json(value: Any) -> object:
|
|
359
|
+
if hasattr(value, "to_json") and callable(value.to_json):
|
|
360
|
+
return value.to_json()
|
|
361
|
+
|
|
362
|
+
if hasattr(value, "to_dict") and callable(value.to_dict):
|
|
363
|
+
value = value.to_dict()
|
|
364
|
+
|
|
365
|
+
return orjson.dumps(value).decode("utf-8")
|
|
366
|
+
|
|
367
|
+
def tagify(self) -> "TagChild":
|
|
368
|
+
"""
|
|
369
|
+
A method for rendering this object via htmltools/shiny.
|
|
370
|
+
"""
|
|
371
|
+
try:
|
|
372
|
+
from htmltools import HTML
|
|
373
|
+
except ImportError:
|
|
374
|
+
raise ImportError(
|
|
375
|
+
".tagify() is only intended to be called by htmltools/shiny, ",
|
|
376
|
+
"but htmltools is not installed. ",
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
if not self.error:
|
|
380
|
+
header = f"View result from <code>{self.name}</code>"
|
|
381
|
+
else:
|
|
382
|
+
header = f"❌ Failed to call tool <code>{self.name}</code>"
|
|
383
|
+
|
|
384
|
+
args = self._arguments_str()
|
|
385
|
+
content = self._get_display_value()
|
|
386
|
+
|
|
387
|
+
return HTML(
|
|
388
|
+
textwrap.dedent(f"""
|
|
389
|
+
<details class="chatlas-tool-result">
|
|
390
|
+
<summary>{header}</summary>
|
|
391
|
+
<div class="chatlas-tool-result-content">
|
|
392
|
+
Result: <p><code>{content}</code></p>
|
|
393
|
+
Arguments: <p><code>{args}</code></p>
|
|
394
|
+
</div>
|
|
395
|
+
</details>
|
|
396
|
+
""")
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
def _arguments_str(self) -> str:
|
|
400
|
+
if isinstance(self.arguments, dict):
|
|
401
|
+
return ", ".join(f"{k}={v}" for k, v in self.arguments.items())
|
|
402
|
+
return str(self.arguments)
|
|
254
403
|
|
|
255
404
|
|
|
256
405
|
class ContentJson(Content):
|
|
@@ -271,7 +420,7 @@ class ContentJson(Content):
|
|
|
271
420
|
content_type: ContentTypeEnum = "json"
|
|
272
421
|
|
|
273
422
|
def __str__(self):
|
|
274
|
-
return
|
|
423
|
+
return orjson.dumps(self.value, option=orjson.OPT_INDENT_2).decode("utf-8")
|
|
275
424
|
|
|
276
425
|
def _repr_markdown_(self):
|
|
277
426
|
return f"""```json\n{self.__str__()}\n```"""
|
|
@@ -345,3 +494,52 @@ def create_content(data: dict[str, Any]) -> ContentUnion:
|
|
|
345
494
|
return ContentPDF.model_validate(data)
|
|
346
495
|
else:
|
|
347
496
|
raise ValueError(f"Unknown content type: {ct}")
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
TOOL_CSS = """
|
|
500
|
+
/* Get dot to appear inline, even when in a paragraph following the request */
|
|
501
|
+
.chatlas-tool-request + p:has(.markdown-stream-dot) {
|
|
502
|
+
display: inline;
|
|
503
|
+
}
|
|
504
|
+
|
|
505
|
+
/* Hide request when anything other than a dot follows it */
|
|
506
|
+
.chatlas-tool-request:not(:has(+ p .markdown-stream-dot)) {
|
|
507
|
+
display: none;
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
.chatlas-tool-request, .chatlas-tool-result {
|
|
511
|
+
font-weight: 300;
|
|
512
|
+
font-size: 0.9rem;
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
.chatlas-tool-result {
|
|
516
|
+
display: inline-block;
|
|
517
|
+
width: 100%;
|
|
518
|
+
margin-bottom: 1rem;
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
.chatlas-tool-result summary {
|
|
522
|
+
list-style: none;
|
|
523
|
+
cursor: pointer;
|
|
524
|
+
}
|
|
525
|
+
|
|
526
|
+
.chatlas-tool-result summary::after {
|
|
527
|
+
content: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='16' height='16' fill='currentColor' class='bi bi-caret-right-fill' viewBox='0 0 16 16'%3E%3Cpath d='m12.14 8.753-5.482 4.796c-.646.566-1.658.106-1.658-.753V3.204a1 1 0 0 1 1.659-.753l5.48 4.796a1 1 0 0 1 0 1.506z'/%3E%3C/svg%3E");
|
|
528
|
+
font-size: 1.15rem;
|
|
529
|
+
margin-left: 0.25rem;
|
|
530
|
+
vertical-align: middle;
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
.chatlas-tool-result[open] summary::after {
|
|
534
|
+
content: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='16' height='16' fill='currentColor' class='bi bi-caret-down-fill' viewBox='0 0 16 16'%3E%3Cpath d='M7.247 11.14 2.451 5.658C1.885 5.013 2.345 4 3.204 4h9.592a1 1 0 0 1 .753 1.659l-4.796 5.48a1 1 0 0 1-1.506 0z'/%3E%3C/svg%3E");
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
.chatlas-tool-result-content {
|
|
538
|
+
border: 1px solid var(--bs-border-color, #0066cc);
|
|
539
|
+
width: 100%;
|
|
540
|
+
padding: 1rem;
|
|
541
|
+
border-radius: var(--bs-border-radius, 0.2rem);
|
|
542
|
+
margin-top: 1rem;
|
|
543
|
+
margin-bottom: 1rem;
|
|
544
|
+
}
|
|
545
|
+
"""
|
chatlas/_databricks.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
|
+
|
|
5
|
+
from ._chat import Chat
|
|
6
|
+
from ._logging import log_model_default
|
|
7
|
+
from ._openai import OpenAIProvider
|
|
8
|
+
from ._turn import Turn, normalize_turns
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from databricks.sdk import WorkspaceClient
|
|
12
|
+
|
|
13
|
+
from ._openai import ChatCompletion
|
|
14
|
+
from .types.openai import SubmitInputArgs
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def ChatDatabricks(
|
|
18
|
+
*,
|
|
19
|
+
system_prompt: Optional[str] = None,
|
|
20
|
+
model: Optional[str] = None,
|
|
21
|
+
turns: Optional[list[Turn]] = None,
|
|
22
|
+
workspace_client: Optional["WorkspaceClient"] = None,
|
|
23
|
+
) -> Chat["SubmitInputArgs", ChatCompletion]:
|
|
24
|
+
"""
|
|
25
|
+
Chat with a model hosted on Databricks.
|
|
26
|
+
|
|
27
|
+
Databricks provides out-of-the-box access to a number of [foundation
|
|
28
|
+
models](https://docs.databricks.com/en/machine-learning/model-serving/score-foundation-models.html)
|
|
29
|
+
and can also serve as a gateway for external models hosted by a third party.
|
|
30
|
+
|
|
31
|
+
Prerequisites
|
|
32
|
+
--------------
|
|
33
|
+
|
|
34
|
+
::: {.callout-note}
|
|
35
|
+
## Python requirements
|
|
36
|
+
|
|
37
|
+
`ChatDatabricks` requires the `databricks-sdk` package: `pip install
|
|
38
|
+
"chatlas[databricks]"`.
|
|
39
|
+
:::
|
|
40
|
+
|
|
41
|
+
::: {.callout-note}
|
|
42
|
+
## Authentication
|
|
43
|
+
|
|
44
|
+
`chatlas` delegates to the `databricks-sdk` package for authentication with
|
|
45
|
+
Databricks. As such, you can use any of the authentication methods discussed
|
|
46
|
+
here:
|
|
47
|
+
|
|
48
|
+
https://docs.databricks.com/aws/en/dev-tools/sdk-python#authentication
|
|
49
|
+
|
|
50
|
+
Note that Python-specific article points to this language-agnostic "unified"
|
|
51
|
+
approach to authentication:
|
|
52
|
+
|
|
53
|
+
https://docs.databricks.com/aws/en/dev-tools/auth/unified-auth
|
|
54
|
+
|
|
55
|
+
There, you'll find all the options listed, but a simple approach that
|
|
56
|
+
generally works well is to set the following environment variables:
|
|
57
|
+
|
|
58
|
+
* `DATABRICKS_HOST`: The Databricks host URL for either the Databricks
|
|
59
|
+
workspace endpoint or the Databricks accounts endpoint.
|
|
60
|
+
* `DATABRICKS_TOKEN`: The Databricks personal access token.
|
|
61
|
+
:::
|
|
62
|
+
|
|
63
|
+
Parameters
|
|
64
|
+
----------
|
|
65
|
+
system_prompt
|
|
66
|
+
A system prompt to set the behavior of the assistant.
|
|
67
|
+
model
|
|
68
|
+
The model to use for the chat. The default, None, will pick a reasonable
|
|
69
|
+
default, and warn you about it. We strongly recommend explicitly
|
|
70
|
+
choosing a model for all but the most casual use.
|
|
71
|
+
turns
|
|
72
|
+
A list of turns to start the chat with (i.e., continuing a previous
|
|
73
|
+
conversation). If not provided, the conversation begins from scratch. Do
|
|
74
|
+
not provide non-`None` values for both `turns` and `system_prompt`. Each
|
|
75
|
+
message in the list should be a dictionary with at least `role` (usually
|
|
76
|
+
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
|
|
77
|
+
there is also a `content` field, which is a string.
|
|
78
|
+
workspace_client
|
|
79
|
+
A `databricks.sdk.WorkspaceClient()` to use for the connection. If not
|
|
80
|
+
provided, a new client will be created.
|
|
81
|
+
|
|
82
|
+
Returns
|
|
83
|
+
-------
|
|
84
|
+
Chat
|
|
85
|
+
A chat object that retains the state of the conversation.
|
|
86
|
+
"""
|
|
87
|
+
if model is None:
|
|
88
|
+
model = log_model_default("databricks-dbrx-instruct")
|
|
89
|
+
|
|
90
|
+
return Chat(
|
|
91
|
+
provider=DatabricksProvider(
|
|
92
|
+
model=model,
|
|
93
|
+
workspace_client=workspace_client,
|
|
94
|
+
),
|
|
95
|
+
turns=normalize_turns(
|
|
96
|
+
turns or [],
|
|
97
|
+
system_prompt,
|
|
98
|
+
),
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class DatabricksProvider(OpenAIProvider):
|
|
103
|
+
def __init__(
|
|
104
|
+
self,
|
|
105
|
+
*,
|
|
106
|
+
model: str,
|
|
107
|
+
workspace_client: Optional["WorkspaceClient"] = None,
|
|
108
|
+
):
|
|
109
|
+
try:
|
|
110
|
+
from databricks.sdk import WorkspaceClient
|
|
111
|
+
except ImportError:
|
|
112
|
+
raise ImportError(
|
|
113
|
+
"`ChatDatabricks()` requires the `databricks-sdk` package. "
|
|
114
|
+
"Install it with `pip install databricks-sdk[openai]`."
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
try:
|
|
118
|
+
import httpx
|
|
119
|
+
from openai import AsyncOpenAI
|
|
120
|
+
except ImportError:
|
|
121
|
+
raise ImportError(
|
|
122
|
+
"`ChatDatabricks()` requires the `openai` package. "
|
|
123
|
+
"Install it with `pip install openai`."
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
self._model = model
|
|
127
|
+
self._seed = None
|
|
128
|
+
|
|
129
|
+
if workspace_client is None:
|
|
130
|
+
workspace_client = WorkspaceClient()
|
|
131
|
+
|
|
132
|
+
client = workspace_client.serving_endpoints.get_open_ai_client()
|
|
133
|
+
|
|
134
|
+
self._client = client
|
|
135
|
+
|
|
136
|
+
# The databricks sdk does currently expose an async client, but we can
|
|
137
|
+
# effectively mirror what .get_open_ai_client() does internally.
|
|
138
|
+
# Note also there is a open PR to add async support that does essentially
|
|
139
|
+
# the same thing:
|
|
140
|
+
# https://github.com/databricks/databricks-sdk-py/pull/851
|
|
141
|
+
self._async_client = AsyncOpenAI(
|
|
142
|
+
base_url=client.base_url,
|
|
143
|
+
api_key="no-token", # A placeholder to pass validations, this will not be used
|
|
144
|
+
http_client=httpx.AsyncClient(auth=client._client.auth),
|
|
145
|
+
)
|
chatlas/_display.py
CHANGED
|
@@ -12,8 +12,14 @@ from ._typing_extensions import TypedDict
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class MarkdownDisplay(ABC):
|
|
15
|
+
"""Base class for displaying markdown content in different environments."""
|
|
16
|
+
|
|
15
17
|
@abstractmethod
|
|
16
|
-
def
|
|
18
|
+
def echo(self, content: str):
|
|
19
|
+
"""
|
|
20
|
+
Display the provided markdown string. This will append the content
|
|
21
|
+
to the current display.
|
|
22
|
+
"""
|
|
17
23
|
pass
|
|
18
24
|
|
|
19
25
|
@abstractmethod
|
|
@@ -26,7 +32,7 @@ class MarkdownDisplay(ABC):
|
|
|
26
32
|
|
|
27
33
|
|
|
28
34
|
class MockMarkdownDisplay(MarkdownDisplay):
|
|
29
|
-
def
|
|
35
|
+
def echo(self, content: str):
|
|
30
36
|
pass
|
|
31
37
|
|
|
32
38
|
def __enter__(self):
|
|
@@ -41,7 +47,7 @@ class LiveMarkdownDisplay(MarkdownDisplay):
|
|
|
41
47
|
Stream chunks of markdown into a rich-based live updating console.
|
|
42
48
|
"""
|
|
43
49
|
|
|
44
|
-
def __init__(self, echo_options: "
|
|
50
|
+
def __init__(self, echo_options: "EchoDisplayOptions"):
|
|
45
51
|
from rich.console import Console
|
|
46
52
|
|
|
47
53
|
self.content: str = ""
|
|
@@ -63,7 +69,7 @@ class LiveMarkdownDisplay(MarkdownDisplay):
|
|
|
63
69
|
|
|
64
70
|
self._markdown_options = echo_options["rich_markdown"]
|
|
65
71
|
|
|
66
|
-
def
|
|
72
|
+
def echo(self, content: str):
|
|
67
73
|
from rich.markdown import Markdown
|
|
68
74
|
|
|
69
75
|
self.content += content
|
|
@@ -98,11 +104,11 @@ class IPyMarkdownDisplay(MarkdownDisplay):
|
|
|
98
104
|
Stream chunks of markdown into an IPython notebook.
|
|
99
105
|
"""
|
|
100
106
|
|
|
101
|
-
def __init__(self, echo_options: "
|
|
107
|
+
def __init__(self, echo_options: "EchoDisplayOptions"):
|
|
102
108
|
self.content: str = ""
|
|
103
109
|
self._css_styles = echo_options["css_styles"]
|
|
104
110
|
|
|
105
|
-
def
|
|
111
|
+
def echo(self, content: str):
|
|
106
112
|
from IPython.display import Markdown, update_display
|
|
107
113
|
|
|
108
114
|
self.content += content
|
|
@@ -143,7 +149,7 @@ class IPyMarkdownDisplay(MarkdownDisplay):
|
|
|
143
149
|
self._ipy_display_id = None
|
|
144
150
|
|
|
145
151
|
|
|
146
|
-
class
|
|
152
|
+
class EchoDisplayOptions(TypedDict):
|
|
147
153
|
rich_markdown: dict[str, Any]
|
|
148
154
|
rich_console: dict[str, Any]
|
|
149
155
|
css_styles: dict[str, str]
|
chatlas/_google.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import json
|
|
5
4
|
from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
|
|
6
5
|
|
|
6
|
+
import orjson
|
|
7
7
|
from pydantic import BaseModel
|
|
8
8
|
|
|
9
9
|
from ._chat import Chat
|
|
@@ -432,7 +432,7 @@ class GoogleProvider(
|
|
|
432
432
|
if content.error:
|
|
433
433
|
resp = {"error": content.error}
|
|
434
434
|
else:
|
|
435
|
-
resp = {"result":
|
|
435
|
+
resp = {"result": content.get_model_value()}
|
|
436
436
|
return Part(
|
|
437
437
|
# TODO: seems function response parts might need role='tool'???
|
|
438
438
|
# https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344
|
|
@@ -470,7 +470,7 @@ class GoogleProvider(
|
|
|
470
470
|
text = part.get("text")
|
|
471
471
|
if text:
|
|
472
472
|
if has_data_model:
|
|
473
|
-
contents.append(ContentJson(value=
|
|
473
|
+
contents.append(ContentJson(value=orjson.loads(text)))
|
|
474
474
|
else:
|
|
475
475
|
contents.append(ContentText(text=text))
|
|
476
476
|
function_call = part.get("function_call")
|
|
@@ -492,9 +492,13 @@ class GoogleProvider(
|
|
|
492
492
|
if name:
|
|
493
493
|
contents.append(
|
|
494
494
|
ContentToolResult(
|
|
495
|
-
id=function_response.get("id") or name,
|
|
496
495
|
value=function_response.get("response"),
|
|
497
|
-
|
|
496
|
+
request=ContentToolRequest(
|
|
497
|
+
id=function_response.get("id") or name,
|
|
498
|
+
name=name,
|
|
499
|
+
# TODO: how to get the arguments?
|
|
500
|
+
arguments={},
|
|
501
|
+
),
|
|
498
502
|
)
|
|
499
503
|
)
|
|
500
504
|
|
chatlas/_ollama.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import json
|
|
4
3
|
import re
|
|
5
4
|
import urllib.request
|
|
6
5
|
from typing import TYPE_CHECKING, Optional
|
|
7
6
|
|
|
7
|
+
import orjson
|
|
8
|
+
|
|
8
9
|
from ._chat import Chat
|
|
9
10
|
from ._openai import ChatOpenAI
|
|
10
11
|
from ._turn import Turn
|
|
@@ -121,7 +122,7 @@ def ChatOllama(
|
|
|
121
122
|
|
|
122
123
|
def ollama_models(base_url: str) -> list[str]:
|
|
123
124
|
res = urllib.request.urlopen(url=f"{base_url}/api/tags")
|
|
124
|
-
data =
|
|
125
|
+
data = orjson.loads(res.read())
|
|
125
126
|
return [re.sub(":latest$", "", x["name"]) for x in data["models"]]
|
|
126
127
|
|
|
127
128
|
|
chatlas/_openai.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import json
|
|
5
4
|
from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
|
|
6
5
|
|
|
6
|
+
import orjson
|
|
7
7
|
from pydantic import BaseModel
|
|
8
8
|
|
|
9
9
|
from ._chat import Chat
|
|
@@ -325,7 +325,8 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
325
325
|
del kwargs_full["tools"]
|
|
326
326
|
|
|
327
327
|
if stream and "stream_options" not in kwargs_full:
|
|
328
|
-
|
|
328
|
+
if self.__class__.__name__ != "DatabricksProvider":
|
|
329
|
+
kwargs_full["stream_options"] = {"include_usage": True}
|
|
329
330
|
|
|
330
331
|
return kwargs_full
|
|
331
332
|
|
|
@@ -432,7 +433,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
432
433
|
"id": x.id,
|
|
433
434
|
"function": {
|
|
434
435
|
"name": x.name,
|
|
435
|
-
"arguments":
|
|
436
|
+
"arguments": orjson.dumps(x.arguments).decode("utf-8"),
|
|
436
437
|
},
|
|
437
438
|
"type": "function",
|
|
438
439
|
}
|
|
@@ -498,8 +499,8 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
498
499
|
elif isinstance(x, ContentToolResult):
|
|
499
500
|
tool_results.append(
|
|
500
501
|
ChatCompletionToolMessageParam(
|
|
501
|
-
#
|
|
502
|
-
content=x.
|
|
502
|
+
# Currently, OpenAI only allows for text content in tool results
|
|
503
|
+
content=cast(str, x.get_model_value()),
|
|
503
504
|
tool_call_id=x.id,
|
|
504
505
|
role="tool",
|
|
505
506
|
)
|
|
@@ -528,7 +529,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
528
529
|
contents: list[Content] = []
|
|
529
530
|
if message.content is not None:
|
|
530
531
|
if has_data_model:
|
|
531
|
-
data =
|
|
532
|
+
data = orjson.loads(message.content)
|
|
532
533
|
contents = [ContentJson(value=data)]
|
|
533
534
|
else:
|
|
534
535
|
contents = [ContentText(text=message.content)]
|
|
@@ -543,8 +544,8 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
543
544
|
|
|
544
545
|
args = {}
|
|
545
546
|
try:
|
|
546
|
-
args =
|
|
547
|
-
except
|
|
547
|
+
args = orjson.loads(func.arguments) if func.arguments else {}
|
|
548
|
+
except orjson.JSONDecodeError:
|
|
548
549
|
raise ValueError(
|
|
549
550
|
f"The model's completion included a tool request ({func.name}) "
|
|
550
551
|
"with invalid JSON for input arguments: '{func.arguments}'"
|